From 4eff3acd6a08e6e3b00fd14d098d34b87c80804b Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Wed, 3 Jun 2026 00:38:44 +0100 Subject: [PATCH 1/7] Make the DAG processor access metadata exclusively through the API server The standalone DAG processor no longer connects to the metadata database. It persists parse results and reads all metadata through the DAG Processing API (AIP-92): a single DagProcessingApiClient routes persistence, bundle state and sync, stale-DAG and warning sweeps, priority-parse and callback claims, and the processor Job lifecycle, with no ORM session in the manager. Bundle-initialization connection/variable reads resolve through the Execution API (the same path workers and triggerers use), so a git connection stored in the metadata database keeps working without direct DB access. The processor parses user code, so it does not hold the signing key or mint its own token: it presents a bearer token the deployment provisions, read from [dag_processor] api_token_path, to both the DAG Processing and Execution APIs. In standalone (a trusted launcher that already holds the signing key), the token is minted and provisioned automatically. Per-loop API calls are guarded so a transient API outage skips a cycle instead of crashing the processor, the heartbeat is throttled, the client retries transient failures, and startup waits for API readiness. 
--- .../cli/commands/dag_processor_command.py | 53 +- .../cli/commands/standalone_command.py | 43 + .../src/airflow/config_templates/config.yml | 32 + .../src/airflow/dag_processing/api_client.py | 214 +++++ .../src/airflow/dag_processing/manager.py | 470 +++++------ .../cli/commands/test_standalone_command.py | 26 + .../tests/unit/dag_processing/test_manager.py | 799 ++---------------- .../test_manager_api_persistence.py | 317 +++++++ .../unit/dag_processing/test_no_db_mode.py | 128 +++ .../unit/dag_processing/test_processor.py | 39 +- 10 files changed, 1093 insertions(+), 1028 deletions(-) create mode 100644 airflow-core/src/airflow/dag_processing/api_client.py create mode 100644 airflow-core/tests/unit/dag_processing/test_manager_api_persistence.py create mode 100644 airflow-core/tests/unit/dag_processing/test_no_db_mode.py diff --git a/airflow-core/src/airflow/cli/commands/dag_processor_command.py b/airflow-core/src/airflow/cli/commands/dag_processor_command.py index f4c303c278dbb..03e27436fe8df 100644 --- a/airflow-core/src/airflow/cli/commands/dag_processor_command.py +++ b/airflow-core/src/airflow/cli/commands/dag_processor_command.py @@ -19,12 +19,13 @@ from __future__ import annotations import logging +import time from typing import Any from airflow.cli.commands.daemon_utils import run_command_with_daemon_option from airflow.dag_processing.manager import DagFileProcessorManager from airflow.jobs.dag_processor_job_runner import DagProcessorJobRunner -from airflow.jobs.job import Job, run_job +from airflow.jobs.job import Job from airflow.utils import cli as cli_utils from airflow.utils.memray_utils import MemrayTraceComponents, enable_memray_trace from airflow.utils.providers_configuration_loader import providers_configuration_loaded @@ -45,6 +46,52 @@ def _create_dag_processor_job_runner(args: Any) -> DagProcessorJobRunner: ) +def _run_dag_processor_job(job_runner: DagProcessorJobRunner) -> None: + """ + Run the job, registering the liveness Job through 
the DAG Processing API. + + The DAG processor holds no metadata-DB connection, so its Job row is registered, + heartbeated, and completed through the DAG Processing API instead of writing the + database directly. The first call waits for the API server to be ready (the processor and + server may start concurrently), the heartbeat is throttled to the job heartrate and never + crashes the loop, and a failure to complete the Job does not mask a parsing-loop error. + """ + client = job_runner.processor._dag_processing_client + client.wait_until_ready() + job_id = client.register_job(job_runner.job_type) + + heartrate = getattr(job_runner.job, "heartrate", 5.0) or 5.0 + last_heartbeat = 0.0 + + def _heartbeat() -> None: + nonlocal last_heartbeat + now = time.monotonic() + if now - last_heartbeat < heartrate: + return + try: + client.job_heartbeat(job_id) + last_heartbeat = now + except Exception: + # Liveness update failure must not crash the processor; the next loop retries. + log.warning("DAG Processing API heartbeat failed; retrying next loop", exc_info=True) + + job_runner.processor.heartbeat = _heartbeat + state = "success" + try: + job_runner._execute() + except SystemExit: + pass + except Exception: + state = "failed" + raise + finally: + try: + client.complete_job(job_id, state=state) + except Exception: + # Don't let a completion failure mask the original parsing-loop exception. 
+ log.warning("Failed to mark DAG processor Job %s as %s", job_id, state, exc_info=True) + + @enable_memray_trace(component=MemrayTraceComponents.dag_processor) @cli_utils.action_cli @providers_configuration_loaded @@ -56,7 +103,7 @@ def dag_processor(args): from airflow.cli.hot_reload import run_with_reloader run_with_reloader( - lambda: run_job(job=job_runner.job, execute_callable=job_runner._execute), + lambda: _run_dag_processor_job(job_runner), process_name="dag-processor", ) return @@ -64,6 +111,6 @@ def dag_processor(args): run_command_with_daemon_option( args=args, process_name="dag-processor", - callback=lambda: run_job(job=job_runner.job, execute_callable=job_runner._execute), + callback=lambda: _run_dag_processor_job(job_runner), should_setup_logging=True, ) diff --git a/airflow-core/src/airflow/cli/commands/standalone_command.py b/airflow-core/src/airflow/cli/commands/standalone_command.py index 2e94637e763c4..81f1215dd1b28 100644 --- a/airflow-core/src/airflow/cli/commands/standalone_command.py +++ b/airflow-core/src/airflow/cli/commands/standalone_command.py @@ -20,6 +20,7 @@ import os import socket import subprocess +import tempfile import threading import time from collections import deque @@ -189,8 +190,50 @@ def calculate_env(self): env["AIRFLOW__CORE__AUTH_MANAGER"] = simple_auth_manager_classpath os.environ["AIRFLOW__CORE__AUTH_MANAGER"] = simple_auth_manager_classpath # also in this process! + self._provision_dag_processor_token(env) + return env + def _provision_dag_processor_token(self, env: dict[str, str]) -> None: + """ + Mint the API token the DAG processor presents and point it at a file via env. + + The DAG processor parses user code, so it must not hold the signing key or mint its own + token. Standalone runs in a trusted context (it already holds the signing key the + scheduler uses to mint task tokens), so it mints the processor's token here and provisions + it through ``[dag_processor] api_token_path``. 
The token carries both the Execution and DAG + Processing audiences (the processor calls both APIs) and a sentinel subject, since the + Execution API is task-instance scoped. + """ + from airflow.api_fastapi.auth.tokens import JWTGenerator, get_signing_args, get_signing_key + + try: + # Materialise a signing key shared by every standalone subprocess, so the api-server + # validates the token with the same key it is minted with here. + secret = get_signing_key("api_auth", "jwt_secret", make_secret_key_if_needed=True) + env["AIRFLOW__API_AUTH__JWT_SECRET"] = secret + os.environ["AIRFLOW__API_AUTH__JWT_SECRET"] = secret + + audiences = [ + conf.get_mandatory_list_value("execution_api", "jwt_audience")[0], + conf.get_mandatory_list_value("dag_processor", "jwt_audience")[0], + ] + token = JWTGenerator( + valid_for=conf.getint("execution_api", "jwt_expiration_time"), + # A JWT ``aud`` may be a list; the generator's hint is single-audience, but the + # processor presents this one token to both the Execution and DAG Processing APIs. 
+ audience=audiences, # type: ignore[arg-type] + issuer=conf.get("api_auth", "jwt_issuer", fallback=None), + **get_signing_args(make_secret_key_if_needed=True), + ).generate({"sub": "00000000-0000-0000-0000-000000000000"}) + + fd, path = tempfile.mkstemp(prefix="airflow-standalone-", suffix=".token") + with os.fdopen(fd, "w") as token_file: + token_file.write(token) + env["AIRFLOW__DAG_PROCESSOR__API_TOKEN_PATH"] = path + except Exception as e: + self.print_output("standalone", f"Could not provision the DAG processor API token: {e}") + def find_user_info(self): if conf.get("core", "simple_auth_manager_all_admins").lower() == "true": # If we have no auth anyways, no need to print or do anything diff --git a/airflow-core/src/airflow/config_templates/config.yml b/airflow-core/src/airflow/config_templates/config.yml index 74332719d0c55..5d7548a880bd2 100644 --- a/airflow-core/src/airflow/config_templates/config.yml +++ b/airflow-core/src/airflow/config_templates/config.yml @@ -511,6 +511,17 @@ core: type: string example: ~ default: ~ + dag_processing_api_server_url: + description: | + The url of the DAG processing api server. The DAG processor never accesses the metadata + database directly; it persists parse results and reads metadata through this api server. + Defaults to ``{BASE_URL}/dag-processing`` (a sibling of ``execution_api_server_url``), + where ``{BASE_URL}`` is the base url of the API server. Set this to point the DAG + processor at a different host. + version_added: 3.3.0 + type: string + example: "http://localhost:8080/dag-processing" + default: ~ multi_team: description: | Whether to run the Airflow environment in multi-team mode. @@ -2974,6 +2985,27 @@ dag_processor: example: "/tmp/some-place" default: ~ + jwt_audience: + version_added: 3.3.0 + description: | + The audience claim the DAG Processing API validates on tokens the DAG processor presents. + Can be a single value or a comma-separated string (all are accepted at validation time). 
+ The token itself is minted by the deployment, not the processor; see ``api_token_path``. + example: "my-unique-airflow-id" + default: "urn:airflow.apache.org:dag-processing" + type: string + + api_token_path: + version_added: 3.3.0 + description: | + Path to a file containing the bearer token the DAG processor presents to the API server + (for both the DAG Processing and Execution APIs). The DAG processor parses user code, so + it must not hold the signing key or mint its own token; the deployment (or control plane) + provisions this token and writes it to the file. If unset, the processor sends no token. + example: "/var/run/airflow/dag-processor-token" + default: ~ + type: string + dag_bundle_config_list: description: | List of backend configs. Must supply name, classpath, and kwargs for each backend. diff --git a/airflow-core/src/airflow/dag_processing/api_client.py b/airflow-core/src/airflow/dag_processing/api_client.py new file mode 100644 index 0000000000000..f61c6c45a88cb --- /dev/null +++ b/airflow-core/src/airflow/dag_processing/api_client.py @@ -0,0 +1,214 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""HTTP client used by the DAG processor to persist parse results via the API (AIP-92).""" + +from __future__ import annotations + +import logging +import time +from datetime import datetime +from importlib import import_module +from typing import TYPE_CHECKING, Any +from urllib.parse import quote + +import httpx + +if TYPE_CHECKING: + from collections.abc import Iterable + + from airflow.dag_processing.processor import DagFileParsingResult + +log = logging.getLogger(__name__) + +# Connection-level retries are safe for every request (the request never reached the server), +# so they are applied uniformly by the transport. Application-level retries (5xx / read timeout) +# are only added for idempotent calls; claim/delete calls must not be retried after the request +# may have been processed server-side, or callbacks/priority-requests could be silently lost. +_CONNECT_RETRIES = 3 +_TRANSIENT_RETRIES = 3 +_TRANSIENT_BACKOFF = 0.5 + + +class DagProcessingApiClient: + """ + Forward DAG-processor persistence to the ``/dag-processing`` API sub-app. + + Replaces the DAG processor manager's direct metadata-DB writes: parse results + and stale reconciliation are sent over HTTP, so the processor process needs no + DB credentials for persistence. A single :class:`httpx.Client` is reused so + connections are pooled across the manager's parse loop. + """ + + def __init__(self, base_url: str, *, token: str | None = None, timeout: float = 30.0) -> None: + self._base_url = base_url.rstrip("/") + headers = {"Authorization": f"Bearer {token}"} if token else {} + # ``retries`` retries connection failures (request never sent) at the transport level. + self._client = httpx.Client( + headers=headers, + timeout=timeout, + transport=httpx.HTTPTransport(retries=_CONNECT_RETRIES), + ) + + def close(self) -> None: + self._client.close() + + def wait_until_ready(self, *, timeout: float = 60.0) -> None: + """ + Block until the DAG Processing API answers ``/health`` or ``timeout`` elapses. 
+ + The DAG processor and API server may start concurrently (e.g. ``airflow standalone``); + this avoids crashing on a cold-start race before the server has bound its socket. + """ + deadline = time.monotonic() + timeout + delay = 0.5 + last_exc: Exception | None = None + while time.monotonic() < deadline: + try: + resp = self._client.get(f"{self._base_url}/health") + if resp.status_code < 500: + return + except httpx.HTTPError as exc: + last_exc = exc + time.sleep(delay) + delay = min(delay * 2, 5.0) + log.warning( + "DAG Processing API at %s not ready after %ss", self._base_url, timeout, exc_info=last_exc + ) + + def _send( + self, + method: str, + path: str, + *, + retry_transient: bool = True, + **kwargs: Any, + ) -> httpx.Response: + """ + Send a request, retrying transient failures. + + Connection failures are retried by the transport. When ``retry_transient`` is set + (idempotent calls only), read timeouts and 5xx responses are additionally retried with + backoff. Non-idempotent claim/delete calls pass ``retry_transient=False`` so a response + lost after the server processed the request does not cause a silent retry that drops rows. + """ + url = f"{self._base_url}{path}" + attempts = _TRANSIENT_RETRIES if retry_transient else 1 + for attempt in range(attempts): + try: + resp = self._client.request(method, url, **kwargs) + except httpx.TransportError: + if not retry_transient or attempt == attempts - 1: + raise + else: + if not (retry_transient and resp.status_code >= 500 and attempt < attempts - 1): + resp.raise_for_status() + return resp + time.sleep(_TRANSIENT_BACKOFF * (2**attempt)) + # Unreachable: the loop either returns or raises on the final attempt. 
+ raise RuntimeError("unreachable") + + def persist_parsing_result( + self, + *, + bundle_name: str, + bundle_version: str | None, + version_data: dict | None, + parsing_result: DagFileParsingResult, + run_duration: float, + relative_fileloc: str | None, + ) -> None: + payload = { + "bundle_name": bundle_name, + "bundle_version": bundle_version, + "version_data": version_data, + "relative_fileloc": relative_fileloc, + "run_duration": run_duration, + "serialized_dags": [d.model_dump(mode="json") for d in parsing_result.serialized_dags], + "import_errors": parsing_result.import_errors, + "warnings": parsing_result.warnings or [], + } + self._send("POST", "/parsing-results", json=payload) + log.debug("Persisted parse result via API: bundle=%s file=%s", bundle_name, relative_fileloc) + + def reconcile(self, *, bundle_name: str, observed_filelocs: Iterable[str]) -> None: + payload = {"observed_filelocs": sorted(observed_filelocs)} + self._send("POST", f"/bundles/{quote(bundle_name, safe='')}/reconcile", json=payload) + log.debug("Reconciled bundle via API: bundle=%s", bundle_name) + + def get_bundle_state(self, bundle_name: str) -> dict | None: + """Return ``{last_refreshed, version}`` for a bundle, or ``None`` if it has no record.""" + resp = self._send("GET", f"/bundles/{quote(bundle_name, safe='')}/state") + data = resp.json() + if not data.get("found"): + return None + last_refreshed = data.get("last_refreshed") + return { + "last_refreshed": datetime.fromisoformat(last_refreshed) if last_refreshed else None, + "version": data.get("version"), + } + + def update_bundle_state(self, bundle_name: str, *, last_refreshed: datetime, version: str | None) -> None: + payload = {"last_refreshed": last_refreshed.isoformat(), "version": version} + self._send("PATCH", f"/bundles/{quote(bundle_name, safe='')}/state", json=payload) + + def sync_bundles(self) -> None: + self._send("POST", "/bundles/sync") + + def deactivate_stale_dags(self, *, stale_dag_threshold: int, last_parsed: 
list[dict]) -> None: + payload = {"stale_dag_threshold": stale_dag_threshold, "last_parsed": last_parsed} + self._send("POST", "/stale-dags", json=payload) + + def purge_inactive_dag_warnings(self) -> None: + self._send("POST", "/purge-warnings") + + def claim_priority_files(self, bundle_names: list[str]) -> list[dict]: + """Claim (select + delete) priority parse requests for the given bundles.""" + # Not idempotent (claim + delete): only the transport's connection-failure retry applies. + resp = self._send( + "POST", + "/priority-parse-requests/claim", + retry_transient=False, + json={"bundle_names": bundle_names}, + ) + return resp.json().get("claimed", []) + + def fetch_callbacks(self, *, bundle_names: list[str], limit: int) -> list: + """Claim callbacks via the API and rebuild typed ``CallbackRequest`` objects.""" + # Not idempotent (claim + delete): only the transport's connection-failure retry applies. + resp = self._send( + "POST", + "/callbacks/claim", + retry_transient=False, + json={"bundle_names": bundle_names, "limit": limit}, + ) + module = import_module("airflow.callbacks.callback_requests") + callbacks = [] + for data in resp.json().get("callbacks", []): + callback_request_class = getattr(module, data["req_class"]) + callbacks.append(callback_request_class.from_json(data["req_data"])) + return callbacks + + def register_job(self, job_type: str) -> int: + """Register the processor's liveness Job row server-side and return its id.""" + resp = self._send("POST", "/jobs", json={"job_type": job_type}) + return resp.json()["job_id"] + + def job_heartbeat(self, job_id: int) -> None: + self._send("POST", f"/jobs/{job_id}/heartbeat") + + def complete_job(self, job_id: int, *, state: str) -> None: + self._send("POST", f"/jobs/{job_id}/complete", json={"state": state}) diff --git a/airflow-core/src/airflow/dag_processing/manager.py b/airflow-core/src/airflow/dag_processing/manager.py index fc00b6730c8ea..83192f7592136 100644 --- 
a/airflow-core/src/airflow/dag_processing/manager.py +++ b/airflow-core/src/airflow/dag_processing/manager.py @@ -32,40 +32,35 @@ import zipfile from collections import OrderedDict, defaultdict from dataclasses import dataclass, field -from datetime import datetime, timedelta +from datetime import datetime from operator import attrgetter, itemgetter from pathlib import Path from typing import TYPE_CHECKING, Any, Literal, NamedTuple, cast import attrs import structlog -from sqlalchemy import select, update -from sqlalchemy.orm import load_only from tabulate import tabulate from uuid6 import uuid7 from airflow._shared.observability.metrics import stats from airflow._shared.observability.metrics.stats import normalize_name_for_stats from airflow._shared.timezones import timezone -from airflow.api_fastapi.execution_api.app import InProcessExecutionAPI from airflow.configuration import conf +from airflow.dag_processing.api_client import DagProcessingApiClient from airflow.dag_processing.bundles.base import ( BundleUsageTrackingManager, unpack_bundle_version, ) from airflow.dag_processing.bundles.manager import DagBundlesManager -from airflow.dag_processing.collection import update_dag_parsing_results_in_db from airflow.dag_processing.processor import DagFileParsingResult, DagFileProcessorProcess from airflow.exceptions import AirflowException -from airflow.models.asset import remove_references_to_deleted_dags -from airflow.models.dag import DagModel -from airflow.models.dagbag import DagPriorityParsingRequest -from airflow.models.dagbundle import DagBundleModel -from airflow.models.dagwarning import DagWarning -from airflow.models.db_callback_request import DbCallbackRequest -from airflow.models.errors import ParseImportError +from airflow.executors.base_executor import get_execution_api_server_url from airflow.observability.metrics import stats_utils from airflow.sdk import SecretCache +from airflow.sdk.api.client import Client +from airflow.sdk.execution_time import 
task_runner +from airflow.sdk.execution_time.comms import GetConnection, GetVariable +from airflow.sdk.execution_time.request_handlers import handle_get_connection, handle_get_variable from airflow.sdk.log import init_log_file, logging_processors from airflow.typing_compat import assert_never from airflow.utils.file import list_py_file_paths, might_contain_dag @@ -74,20 +69,16 @@ from airflow.utils.process_utils import ( kill_child_processes_by_pids, ) -from airflow.utils.retries import retry_db_transaction -from airflow.utils.session import NEW_SESSION, create_session, provide_session -from airflow.utils.sqlalchemy import prohibit_commit, with_row_locks if TYPE_CHECKING: from collections.abc import Callable, Iterable, Iterator, Sequence from socket import socket - from sqlalchemy.orm import Session - from sqlalchemy.sql import Select - from airflow.callbacks.callback_requests import CallbackRequest from airflow.dag_processing.bundles.base import BaseDagBundle - from airflow.sdk.api.client import Client + + +log = logging.getLogger(__name__) class DagParsingStat(NamedTuple): @@ -149,6 +140,74 @@ def _config_get_factory(section: str, key: str): return functools.partial(conf.get, section, key) +def _dag_processing_api_server_url() -> str: + """ + Resolve the DAG Processing API URL the processor persists through. + + Defaults to the ``/dag-processing`` mount on the configured API server (a sibling of the + ``/execution`` mount), so a standard deployment that already runs the API server needs no + extra configuration. Set ``[core] dag_processing_api_server_url`` to point at a different + host. + + The default is derived from the resolved Execution API URL (``execution_api_server_url``), + not just ``[api] base_url``, so a deployment that points the processor's Execution API at an + internal service URL keeps both sibling clients on the same host. 
+ """ + explicit = conf.get("core", "dag_processing_api_server_url", fallback=None) + if explicit: + return explicit + execution_url = get_execution_api_server_url().rstrip("/") + if execution_url.endswith("/execution"): + return f"{execution_url[: -len('/execution')]}/dag-processing" + return f"{execution_url}/dag-processing" + + +def _api_token() -> str | None: + """ + Return the token the DAG processor presents to the API server, or ``None``. + + The DAG processor parses (and forks) user code, so it must not hold the deployment signing + key or be able to mint tokens. It only *carries* a token provisioned by a trusted component: + read from the file at ``[dag_processor] api_token_path`` (written by the deployment / control + plane). The same token authenticates both the DAG Processing and Execution API clients. The + issuance of that token is intentionally left to the deployment (the AIP-92 non-task principal + question); the processor never signs one itself. + """ + token_path = conf.get("dag_processor", "api_token_path", fallback=None) + if not token_path: + return None + try: + return Path(token_path).read_text().strip() or None + except OSError: + log.warning("Could not read the DAG processor API token from %s", token_path) + return None + + +class _DagProcessorSecretsComms: + """ + Minimal ``SUPERVISOR_COMMS`` for the DAG processor process. + + Bundle initialization (e.g. ``GitDagBundle`` resolving its git connection) runs in the + manager process, which holds no metadata-DB connection. Installing this as + ``task_runner.SUPERVISOR_COMMS`` makes ``ensure_secrets_backend_loaded()`` select + ``ExecutionAPISecretsBackend``; this shim then resolves the connection/variable lookups it + sends through the manager's remote Execution API client -- the same path the worker and + triggerer use -- so DB-stored bundle credentials resolve without direct DB access. 
+ """ + + def __init__(self, client: Client) -> None: + self._client = client + + def send(self, msg: Any, **kwargs: Any) -> Any: + if isinstance(msg, GetConnection): + return handle_get_connection(self._client, msg)[0] + if isinstance(msg, GetVariable): + return handle_get_variable(self._client, msg)[0] + # Other messages (e.g. MaskSecret emitted while masking a fetched secret) do not apply to + # the manager process; ignore them. + return None + + def _resolve_path(instance: Any, attribute: attrs.Attribute, val: str | os.PathLike[str] | None): if val is not None: val = Path(val).resolve() @@ -275,8 +334,12 @@ class DagFileProcessorManager(LoggingMixin): factory=_config_get_factory("dag_processor", "file_parsing_sort_mode") ) - _api_server: InProcessExecutionAPI = attrs.field(init=False, factory=InProcessExecutionAPI) - """API server to interact with Metadata DB""" + _dag_processing_client: DagProcessingApiClient = attrs.field( + init=False, + factory=lambda: DagProcessingApiClient(_dag_processing_api_server_url(), token=_api_token()), + ) + """Client for the DAG Processing API. The DAG processor never reads or writes the metadata + database directly; all persistence and metadata reads are routed through the API server.""" def register_exit_signals(self): """Register signals that stop child processes.""" @@ -295,8 +358,8 @@ def _exit_gracefully(self, signum, frame): sys.exit(os.EX_OK) def sync_bundles(self) -> None: - """Sync configured DAG bundles to the metadata database.""" - DagBundlesManager().sync_bundles_to_db() + """Sync configured DAG bundles to the metadata database via the DAG Processing API.""" + self._dag_processing_client.sync_bundles() def get_all_bundles(self) -> list[BaseDagBundle]: """Return configured DAG bundles filtered by ``bundle_names_to_parse`` if provided.""" @@ -317,35 +380,44 @@ def run(self): def before_run(self) -> None: """Set up state required before the parsing loop starts. 
Default implementation; override to customize.""" - self.prepare_server_process_context() self.prepare_process_context() self.register_exit_signals() self.log.info("Processing files using up to %s processes at a time ", self._parallelism) self.log.info("Process each file at most once every %s seconds", self._file_process_interval) + self._setup_secrets_comms() self.prepare_bundles() self._symlink_latest_log_directory() # To prevent COW in forked process parsing dag file gc.freeze() def after_run(self) -> None: - """Tear down state after the parsing loop exits. Default no-op; override to customize.""" + """Tear down state after the parsing loop exits.""" + # Drop the secrets shim installed in before_run so it does not outlive the loop. + if isinstance(getattr(task_runner, "SUPERVISOR_COMMS", None), _DagProcessorSecretsComms): + del task_runner.SUPERVISOR_COMMS - def prepare_server_process_context(self) -> None: + def _setup_secrets_comms(self) -> None: """ - Mark this process as running in "server" context so MetastoreBackend is available. + Route the manager's own connection/variable lookups through the remote Execution API. - Override to a no-op in subclasses that do not require direct DB access (e.g. API-backed - deployments under AIP-92). + Bundle initialization resolves credentials here (see :class:`_DagProcessorSecretsComms`). + Child parser processes reset ``SUPERVISOR_COMMS`` to their own comms in + ``_parse_file_entrypoint()``, so this only affects the manager process. """ - # TODO: Temporary until AIP-92 removes DB access from DagProcessorManager. - # The manager needs MetastoreBackend to retrieve connections from the database - # during bundle initialization (e.g., GitDagBundle.__init__ → GitHook needs git credentials). - # This marks the manager as "server" context so ensure_secrets_backend_loaded() provides - # MetastoreBackend instead of falling back to EnvironmentVariablesBackend only. 
- # Child parser processes explicitly override this by setting _AIRFLOW_PROCESS_CONTEXT=client - # in _parse_file_entrypoint() to prevent inheriting server privileges. - # Related: https://github.com/apache/airflow/pull/57459 - os.environ["_AIRFLOW_PROCESS_CONTEXT"] = "server" + try: + comms = _DagProcessorSecretsComms(self.client) + except Exception: + # Without a signing key the processor cannot authenticate to the Execution API; + # skip the shim so bundle init still resolves env/external-backend credentials + # instead of crashing. (DB-stored bundle credentials will not resolve in this case.) + self.log.warning( + "Could not initialize the Execution API client for bundle credentials; " + "bundle-init connections will resolve only from non-database secrets backends", + exc_info=True, + ) + return + # Duck-typed comms: the shim implements the .send() interface the secrets backend uses. + task_runner.SUPERVISOR_COMMS = comms # type: ignore[assignment] def prepare_process_context(self) -> None: """Initialize transport-neutral process state (selector, stats) before the parsing loop starts.""" @@ -406,62 +478,24 @@ def cleanup_stale_bundle_versions(self) -> None: """Clean up stale DAG bundle version usage records.""" BundleUsageTrackingManager().remove_stale_bundle_versions() - @provide_session - def deactivate_stale_dags( - self, - last_parsed: dict[DagFileInfo, datetime | None], - *, - session: Session = NEW_SESSION, - ): - """Detect and deactivate DAGs which are no longer present in files.""" - to_deactivate = set() - inactive_bundles = set( - session.scalars(select(DagBundleModel.name).where(DagBundleModel.active.is_(False))).all() - ) - query = select( - DagModel.dag_id, - DagModel.bundle_name, - DagModel.fileloc, - DagModel.last_parsed_time, - DagModel.relative_fileloc, - ).where(~DagModel.is_stale) - dags_parsed = session.execute(query) - - for dag in dags_parsed: - # Dags whose bundle has been removed from config (bundle no longer active) are stale — - # the 
processor has stopped parsing their files, so the time-based check below would never fire. - if dag.bundle_name in inactive_bundles: - self.log.info( - "Deactivating Dag %s. Its bundle %s is no longer active.", - dag.dag_id, - dag.bundle_name, - ) - to_deactivate.add(dag.dag_id) - continue - # When the Dag's last_parsed_time is more than the stale_dag_threshold older than the - # Dag file's last_finish_time, the Dag is considered stale as has apparently been removed from the file, - # This is especially relevant for Dag files that generate Dags in a dynamic manner. - file_info = DagFileInfo(rel_path=Path(dag.relative_fileloc), bundle_name=dag.bundle_name) - if last_finish_time := last_parsed.get(file_info, None): - if dag.last_parsed_time + timedelta(seconds=self.stale_dag_threshold) < last_finish_time: - self.log.info( - "Deactivating stale DAG %s. Not parsed for %s seconds (last parsed: %s).", - dag.dag_id, - int((last_finish_time - dag.last_parsed_time).total_seconds()), - dag.last_parsed_time, - ) - to_deactivate.add(dag.dag_id) - - if to_deactivate: - deactivated_dagmodel = session.execute( - update(DagModel) - .where(DagModel.dag_id.in_(to_deactivate)) - .values(is_stale=True) - .execution_options(synchronize_session="fetch") + def deactivate_stale_dags(self, last_parsed: dict[DagFileInfo, datetime | None]) -> None: + """Detect and deactivate DAGs which are no longer present in files, via the DAG Processing API.""" + entries = [ + { + "bundle_name": file_info.bundle_name, + "relative_fileloc": str(file_info.rel_path), + "last_finish_time": last_finish_time.isoformat(), + } + for file_info, last_finish_time in last_parsed.items() + if last_finish_time is not None + ] + try: + self._dag_processing_client.deactivate_stale_dags( + stale_dag_threshold=int(self.stale_dag_threshold), last_parsed=entries ) - deactivated = getattr(deactivated_dagmodel, "rowcount", 0) - if deactivated: - self.log.info("Deactivated %i DAGs which are no longer present in file.", 
deactivated) + except Exception: + # A transient API outage must not crash the parse loop; the next cycle retries. + self.log.exception("Error deactivating stale DAGs via the DAG Processing API") def _run_parsing_loop(self): # initialize cache to mutualize calls to Variable.get in DAGs @@ -551,12 +585,23 @@ def _queue_requested_files_for_parsing(self) -> None: self.log.info("Bundles being force refreshed: %s", ", ".join(self._force_refresh_bundles)) def claim_priority_files(self) -> list[DagFileInfo]: - """ - Fetch and claim files requested for priority parsing. - - Default implementation reads from the metadata DB; override to source requests from an API. - """ - return self._claim_priority_files() + """Fetch and claim files requested for priority parsing, via the DAG Processing API.""" + bundles = {bundle.name: bundle for bundle in self._dag_bundles} + try: + claimed = self._dag_processing_client.claim_priority_files(list(bundles)) + except Exception: + # A transient API outage must not crash the parse loop; the next cycle retries. 
+ self.log.exception("Error claiming priority parse requests via the DAG Processing API") + return [] + return [ + DagFileInfo( + rel_path=Path(entry["relative_fileloc"]), + bundle_name=entry["bundle_name"], + bundle_path=bundles[entry["bundle_name"]].path, + ) + for entry in claimed + if entry["bundle_name"] in bundles + ] def request_bundle_refresh(self, bundle_names: str | Iterable[str]) -> None: """ @@ -587,66 +632,17 @@ def should_skip_refresh( and bundle.name not in self._force_refresh_bundles ) - @provide_session - def _claim_priority_files(self, *, session: Session = NEW_SESSION) -> list[DagFileInfo]: - """Fetch priority parsing requests from the metadata database.""" - files: list[DagFileInfo] = [] - bundles = {b.name: b for b in self._dag_bundles} - requests = session.scalars( - select(DagPriorityParsingRequest).where(DagPriorityParsingRequest.bundle_name.in_(bundles.keys())) - ) - for request in requests: - bundle = bundles[request.bundle_name] - files.append( - DagFileInfo( - rel_path=Path(request.relative_fileloc), bundle_name=bundle.name, bundle_path=bundle.path - ) - ) - session.delete(request) - return files - def fetch_callbacks(self) -> list[CallbackRequest]: - """ - Fetch and claim callbacks for this manager's bundles. - - Default implementation reads from the metadata DB; override to source callbacks from an API. 
- """ - return self._fetch_callbacks_from_db() - - @provide_session - @retry_db_transaction - def _fetch_callbacks_from_db( - self, - *, - session: Session = NEW_SESSION, - ) -> list[CallbackRequest]: - """Fetch callbacks from database and add them to the internal queue for execution.""" - self.log.debug("Fetching callbacks from the database.") - - callback_queue: list[CallbackRequest] = [] - with prohibit_commit(session) as guard: - bundle_names = [bundle.name for bundle in self._dag_bundles] - query: Select[tuple[DbCallbackRequest]] = with_row_locks( - select(DbCallbackRequest) - .where(DbCallbackRequest.bundle_name.in_(bundle_names)) - .order_by(DbCallbackRequest.priority_weight.desc()) - .limit(self.max_callbacks_per_loop), - of=DbCallbackRequest, - session=session, - skip_locked=True, + """Fetch and claim callbacks for this manager's bundles, via the DAG Processing API.""" + try: + return self._dag_processing_client.fetch_callbacks( + bundle_names=[bundle.name for bundle in self._dag_bundles], + limit=self.max_callbacks_per_loop, ) - callbacks: Sequence[DbCallbackRequest] = [ - cb[0] if isinstance(cb, tuple) else cb for cb in session.scalars(query) - ] - for callback in callbacks: - req = callback.get_callback_request() - try: - callback_queue.append(req) - session.delete(callback) - except Exception as e: - self.log.warning("Error adding callback for execution: %s, %s", callback, e) - guard.commit() - return callback_queue + except Exception: + # A transient API outage must not crash the parse loop; the next cycle retries. 
+ self.log.exception("Error fetching callbacks via the DAG Processing API") + return [] def prepare_callback_bundle(self, request: CallbackRequest) -> BaseDagBundle | None: """ @@ -689,51 +685,35 @@ def _add_callback_to_queue(self, request: CallbackRequest) -> None: self._add_files_to_queue([file_info], mode="front") stats.incr("dag_processing.other_callback_count") - @provide_session - def get_bundle_state(self, bundle_name: str, *, session: Session = NEW_SESSION) -> BundleState | None: + def get_bundle_state(self, bundle_name: str) -> BundleState | None: """ - Return the persisted refresh state for a bundle. + Return the persisted refresh state for a bundle, via the DAG Processing API. - Returns ``None`` if the bundle has no database record. + Returns ``None`` if the bundle has no record. """ - row = session.scalar( - select(DagBundleModel) - .where(DagBundleModel.name == bundle_name) - .options(load_only(DagBundleModel.last_refreshed, DagBundleModel.version)) - ) - if row is None: + data = self._dag_processing_client.get_bundle_state(bundle_name) + if data is None: return None - return BundleState(last_refreshed=row.last_refreshed, version=row.version) + return BundleState(last_refreshed=data["last_refreshed"], version=data["version"]) - @provide_session - def update_bundle_state( - self, - bundle_name: str, - *, - last_refreshed: datetime, - version: str | None, - session: Session = NEW_SESSION, - ) -> None: + def update_bundle_state(self, bundle_name: str, *, last_refreshed: datetime, version: str | None) -> None: """ - Persist the post-refresh state for a bundle. + Persist the post-refresh state for a bundle, via the DAG Processing API. - Always updates ``last_refreshed``. Updates ``version`` only when ``version`` is not - ``None`` — pass ``None`` to leave the stored version unchanged (e.g. for non-versioned - bundles or when the version did not change after a refresh). 
+ Always updates ``last_refreshed``; updates ``version`` only when ``version`` is not + ``None`` (pass ``None`` to leave the stored version unchanged). """ - values: dict[str, Any] = {"last_refreshed": last_refreshed} - if version is not None: - values["version"] = version - session.execute(update(DagBundleModel).where(DagBundleModel.name == bundle_name).values(**values)) + self._dag_processing_client.update_bundle_state( + bundle_name, last_refreshed=last_refreshed, version=version + ) def purge_inactive_dag_warnings(self) -> None: - """ - Purge warnings for inactive/stale DAGs. - - Default implementation deletes records from the metadata DB; override to - source warnings from an API or skip the cleanup entirely. - """ - DagWarning.purge_inactive_dag_warnings() + """Purge warnings for inactive/stale DAGs, via the DAG Processing API.""" + try: + self._dag_processing_client.purge_inactive_dag_warnings() + except Exception: + # A transient API outage must not crash the parse loop; the next cycle retries. 
+ self.log.exception("Error purging inactive DAG warnings via the DAG Processing API") def _refresh_dag_bundles(self, known_files: dict[str, set[DagFileInfo]]): """Refresh DAG bundles, if required.""" @@ -846,11 +826,17 @@ def _refresh_dag_bundles(self, known_files: dict[str, set[DagFileInfo]]): known_files[bundle.name] = found_files - self.deactivate_deleted_dags(bundle_name=bundle.name, present=found_files) - self.clear_orphaned_import_errors( - bundle_name=bundle.name, - observed_filelocs=self._get_observed_filelocs(found_files), - ) + try: + self._dag_processing_client.reconcile( + bundle_name=bundle.name, + observed_filelocs=self._get_observed_filelocs(found_files), + ) + except Exception: + self.log.exception( + "Error reconciling bundle %s via the DAG Processing API; " + "skipping stale reconciliation this cycle", + bundle.name, + ) if any_refreshed: self.handle_removed_files(known_files=known_files) @@ -899,20 +885,6 @@ def find_zipped_dags(abs_path: os.PathLike) -> Iterator[str]: return observed_filelocs - def deactivate_deleted_dags(self, bundle_name: str, present: set[DagFileInfo]) -> None: - """Deactivate DAGs that come from files that are no longer present in bundle.""" - observed_filelocs = self._get_observed_filelocs(present) - with create_session() as session: - any_deactivated = DagModel.deactivate_deleted_dags( - bundle_name=bundle_name, - rel_filelocs=observed_filelocs, - session=session, - ) - # Only run cleanup if we actually deactivated any DAGs - # This avoids unnecessary DELETE queries in the common case where no DAGs were deleted - if any_deactivated: - remove_references_to_deleted_dags(session=session) - def print_stats(self, known_files: dict[str, set[DagFileInfo]]): """Occasionally print out stats about how fast the files are getting processed.""" if 0 < self.print_stats_interval < time.monotonic() - self.last_stat_print_time: @@ -920,28 +892,6 @@ def print_stats(self, known_files: dict[str, set[DagFileInfo]]): 
self._log_file_processing_stats(known_files=known_files) self.last_stat_print_time = time.monotonic() - @provide_session - def clear_orphaned_import_errors( - self, bundle_name: str, observed_filelocs: set[str], *, session: Session = NEW_SESSION - ): - """ - Clear import errors for files that no longer exist. - - :param session: session for ORM operations - """ - self.log.debug("Removing old import errors") - try: - errors = session.scalars( - select(ParseImportError) - .where(ParseImportError.bundle_name == bundle_name) - .options(load_only(ParseImportError.filename)) - ) - for error in errors: - if error.filename not in observed_filelocs: - session.delete(error) - except Exception: - self.log.exception("Error removing old import errors") - def _log_file_processing_stats(self, known_files: dict[str, set[DagFileInfo]]): """ Print out stats about how files are getting processed. @@ -1091,13 +1041,10 @@ def terminate_orphan_processes(self, present: set[DagFileInfo]): processor.logger_filehandle.close() self._file_stats.pop(file, None) - @provide_session def handle_parsing_result( self, file: DagFileInfo, proc: DagFileProcessorProcess, - *, - session: Session = NEW_SESSION, ) -> None: """ Post-process a single finished parse result. @@ -1107,11 +1054,6 @@ def handle_parsing_result( Extracted from ``_collect_results`` to keep result handling and persistence separate. - Owns its own DB session via ``@provide_session`` so subclasses that - forward results without touching the metadata DB (e.g. AIP-92 API-backed - deployments) can override this method without inheriting a session - created by the caller. 
- If persistence fails, the error is logged and the previous persisted DAG/import-error counts are preserved while a minimal timestamp update throttles immediate retries, so other files in the same @@ -1142,7 +1084,6 @@ def handle_parsing_result( parsing_result=proc.parsing_result, run_duration=run_duration, relative_fileloc=str(file.rel_path), - session=session, ) except Exception: self.log.exception( @@ -1174,36 +1115,15 @@ def persist_parsing_result( parsing_result: DagFileParsingResult, run_duration: float, relative_fileloc: str | None, - session: Session, ) -> None: - """Persist parsed DAG data to the metadata database.""" - import_errors: dict[tuple[str, str], str] = {} - if parsing_result.import_errors: - import_errors = { - (bundle_name, rel_path): error for rel_path, error in parsing_result.import_errors.items() - } - - # Build the set of files that were parsed. This includes the file that was parsed, - # even if it no longer contains DAGs, so we can clear old import errors. - files_parsed: set[tuple[str, str]] | None = None - if relative_fileloc is not None: - files_parsed = {(bundle_name, relative_fileloc)} - files_parsed.update(import_errors.keys()) - - warnings = parsing_result.warnings or [] - if warnings and isinstance(warnings[0], dict): - warnings = [DagWarning(**warn) for warn in warnings] - - update_dag_parsing_results_in_db( + """Persist parsed DAG data via the DAG Processing API.""" + self._dag_processing_client.persist_parsing_result( bundle_name=bundle_name, bundle_version=bundle_version, version_data=version_data, - dags=parsing_result.serialized_dags, - import_errors=import_errors, - parse_duration=run_duration, - warnings=set(warnings), - session=session, - files_parsed=files_parsed, + parsing_result=parsing_result, + run_duration=run_duration, + relative_fileloc=relative_fileloc, ) def _collect_results(self): @@ -1269,12 +1189,14 @@ def _get_logger_for_dag_file(self, dag_file: DagFileInfo): @functools.cached_property def client(self) -> 
Client: - from airflow.sdk.api.client import Client - - client = Client(base_url=None, token="", dry_run=True, transport=self._api_server.transport) - # Mypy is wrong -- the setter accepts a string on the property setter! `URLType = URL | str` - client.base_url = "http://in-process.invalid./" - return client + # Parse-time connection/variable/xcom reads go to the remote Execution API, so the + # processor holds no metadata-DB connection. It carries the externally-provisioned token + # (see _api_token); it does not mint one, since it parses user code. + return Client( + base_url=get_execution_api_server_url(), + token=_api_token() or "", + dry_run=False, + ) def _create_process(self, dag_file: DagFileInfo) -> DagFileProcessorProcess: id = uuid7() diff --git a/airflow-core/tests/unit/cli/commands/test_standalone_command.py b/airflow-core/tests/unit/cli/commands/test_standalone_command.py index 993f51ca2d72d..8bb2b56b4e1ce 100644 --- a/airflow-core/tests/unit/cli/commands/test_standalone_command.py +++ b/airflow-core/tests/unit/cli/commands/test_standalone_command.py @@ -20,6 +20,7 @@ import os from collections import deque from importlib import reload +from pathlib import Path from unittest import mock import pytest @@ -32,6 +33,8 @@ LOCAL_EXECUTOR, ) +from tests_common.test_utils.config import conf_vars + class TestStandaloneCommand: @pytest.mark.parametrize( @@ -82,6 +85,29 @@ class FakeExecutor: assert "AIRFLOW__CORE__AUTH_MANAGER" not in env + @mock.patch("airflow.api_fastapi.auth.tokens.JWTGenerator") + @mock.patch("airflow.api_fastapi.auth.tokens.get_signing_args", return_value={"secret_key": "s"}) + @mock.patch("airflow.api_fastapi.auth.tokens.get_signing_key", return_value="the-secret") + def test_provision_dag_processor_token(self, _get_key, _get_args, mock_generator): + """Standalone mints the processor's token and provisions it via env + a token file.""" + mock_generator.return_value.generate.return_value = "minted-token" + env: dict[str, str] = {} + with 
conf_vars( + { + ("execution_api", "jwt_audience"): "exec-aud", + ("dag_processor", "jwt_audience"): "dag-proc-aud", + } + ): + StandaloneCommand()._provision_dag_processor_token(env) + + # The processor never mints: standalone provisions the token + a shared signing key. + assert env["AIRFLOW__API_AUTH__JWT_SECRET"] == "the-secret" + token_path = env["AIRFLOW__DAG_PROCESSOR__API_TOKEN_PATH"] + assert Path(token_path).read_text() == "minted-token" + # Token is minted for both the Execution and DAG Processing audiences. + assert mock_generator.call_args.kwargs["audience"] == ["exec-aud", "dag-proc-aud"] + Path(token_path).unlink() + @mock.patch("airflow.cli.commands.standalone_command.os.path.exists", return_value=False) @mock.patch("airflow.cli.commands.standalone_command.create_auth_manager") @mock.patch("airflow.cli.commands.standalone_command.conf.get") diff --git a/airflow-core/tests/unit/dag_processing/test_manager.py b/airflow-core/tests/unit/dag_processing/test_manager.py index 941d090f22994..24fe6c5b7392a 100644 --- a/airflow-core/tests/unit/dag_processing/test_manager.py +++ b/airflow-core/tests/unit/dag_processing/test_manager.py @@ -23,7 +23,6 @@ import os import random import re -import shutil import signal import textwrap import time @@ -38,14 +37,13 @@ import msgspec import pytest import time_machine -from sqlalchemy import func, select +from sqlalchemy import select, update from uuid6 import uuid7 from airflow._shared.timezones import timezone from airflow.callbacks.callback_requests import DagCallbackRequest from airflow.dag_processing.bundles.base import BaseDagBundle from airflow.dag_processing.bundles.manager import DagBundlesManager -from airflow.dag_processing.dagbag import DagBag from airflow.dag_processing.manager import ( BundleState, DagFileInfo, @@ -53,19 +51,13 @@ DagFileStat, ) from airflow.dag_processing.processor import DagFileParsingResult, DagFileProcessorProcess -from airflow.models import DagModel, DbCallbackRequest -from 
airflow.models.asset import TaskOutletAssetReference -from airflow.models.dag_version import DagVersion +from airflow.models import DagModel from airflow.models.dagbundle import DagBundleModel -from airflow.models.dagcode import DagCode -from airflow.models.serialized_dag import SerializedDagModel from airflow.models.team import Team from airflow.utils.net import get_hostname from airflow.utils.session import create_session -from tests_common.test_utils.compat import ParseImportError from tests_common.test_utils.config import conf_vars -from tests_common.test_utils.dag import sync_dag_to_db from tests_common.test_utils.db import ( clear_db_assets, clear_db_callbacks, @@ -162,6 +154,52 @@ def _disable_examples(self): with conf_vars({("core", "load_examples"): "False"}): yield + @pytest.fixture(autouse=True) + def mock_dag_processing_client(self): + """Patch the DAG Processing API client so managers don't hit a real API server. + + ``DagFileProcessorManager`` now always holds a ``DagProcessingApiClient`` and routes + every persistence/metadata operation through it; in unit tests there is no API server, + so the real client raises ``httpx.ConnectError``. This replaces it with a mock whose + bundle-state read/write still round-trips through the metadata DB (so refresh-mechanic + tests behave as before), while persistence/reconcile/claim calls are inert. + + ``get_bundle_state`` falls back to a stale state when no ``DagBundleModel`` row exists. + In production the server creates that row during ``sync_bundles`` before parsing; here + ``sync_bundles`` is inert, so without the fallback every bundle would be skipped with + "Bundle model not found" and no files would ever be scanned. + + The mock is yielded so individual tests can override return values (e.g. + ``client.get_bundle_state.return_value = {"last_refreshed": ..., "version": ...}``). + """ + # A "stale" timestamp so bundles with refresh_interval=0 always refresh by default. 
+ stale_last_refreshed = timezone.datetime(2000, 1, 1) + + def _db_get_bundle_state(bundle_name): + with create_session() as session: + row = session.scalar(select(DagBundleModel).where(DagBundleModel.name == bundle_name)) + if row is None: + return {"last_refreshed": stale_last_refreshed, "version": None} + return {"last_refreshed": row.last_refreshed, "version": row.version} + + def _db_update_bundle_state(bundle_name, *, last_refreshed, version): + values: dict = {"last_refreshed": last_refreshed} + if version is not None: + values["version"] = version + with create_session() as session: + session.execute( + update(DagBundleModel).where(DagBundleModel.name == bundle_name).values(**values) + ) + session.commit() + + with mock.patch("airflow.dag_processing.manager.DagProcessingApiClient") as mock_client_cls: + client = mock_client_cls.return_value + client.get_bundle_state.side_effect = _db_get_bundle_state + client.update_bundle_state.side_effect = _db_update_bundle_state + client.claim_priority_files.return_value = [] + client.fetch_callbacks.return_value = [] + yield client + def setup_method(self): clear_db_teams() clear_db_assets() @@ -204,76 +242,6 @@ def mock_processor(self, start_time: float | None = None) -> tuple[DagFileProces ret._open_sockets.clear() return ret, read_end - @pytest.fixture - def clear_parse_import_errors(self): - clear_db_import_errors() - - @pytest.mark.usefixtures("clear_parse_import_errors") - @conf_vars({("core", "load_examples"): "False"}) - def test_remove_file_clears_import_error(self, tmp_path, configure_testing_dag_bundle): - path_to_parse = tmp_path / "temp_dag.py" - - # Generate original import error - path_to_parse.write_text("an invalid airflow DAG") - - with configure_testing_dag_bundle(path_to_parse): - manager = DagFileProcessorManager( - max_runs=1, - processor_timeout=365 * 86_400, - ) - - manager.run() - - with create_session() as session: - import_errors = session.scalars(select(ParseImportError)).all() - assert 
len(import_errors) == 1 - - path_to_parse.unlink() - - # Rerun the parser once the dag file has been removed - manager.run() - - with create_session() as session: - import_errors = session.scalars(select(ParseImportError)).all() - - assert len(import_errors) == 0 - session.rollback() - - @pytest.mark.usefixtures("clear_parse_import_errors") - def test_clear_orphaned_import_errors_keeps_zip_inner_file_errors(self, session, tmp_path): - zip_path = tmp_path / "test_zip.zip" - _create_zip_bundle_with_valid_and_broken_dags(zip_path) - - session.add( - ParseImportError( - filename="test_zip.zip/broken_dag.py", - bundle_name="testing", - timestamp=timezone.utcnow(), - stacktrace="zip import error", - ) - ) - session.flush() - - manager = DagFileProcessorManager(max_runs=1) - manager.clear_orphaned_import_errors( - bundle_name="testing", - observed_filelocs=manager._get_observed_filelocs( - { - DagFileInfo( - bundle_name="testing", - rel_path=Path("test_zip.zip"), - bundle_path=tmp_path, - ) - } - ), - session=session, - ) - session.flush() - - import_errors = session.scalars(select(ParseImportError)).all() - assert len(import_errors) == 1 - assert import_errors[0].filename == "test_zip.zip/broken_dag.py" - def test_get_observed_filelocs_expands_zip_inner_paths(self, tmp_path): zip_path = tmp_path / "test_zip.zip" _create_zip_bundle_with_valid_and_broken_dags(zip_path) @@ -294,63 +262,24 @@ def test_get_observed_filelocs_expands_zip_inner_paths(self, tmp_path): "test_zip.zip/broken_dag.py", } - @pytest.mark.usefixtures("clear_parse_import_errors") - def test_refresh_dag_bundles_keeps_zip_inner_file_errors(self, session, tmp_path, configure_dag_bundles): - bundle_path = tmp_path / "bundleone" - bundle_path.mkdir() - zip_path = bundle_path / "test_zip.zip" - _create_zip_bundle_with_valid_and_broken_dags(zip_path) - - session.add( - ParseImportError( - filename="test_zip.zip/broken_dag.py", - bundle_name="bundleone", - timestamp=timezone.utcnow(), - stacktrace="zip import 
error", - ) - ) - session.flush() - - with configure_dag_bundles({"bundleone": bundle_path}): - DagBundlesManager().sync_bundles_to_db() - manager = DagFileProcessorManager(max_runs=1) - manager._dag_bundles = list(DagBundlesManager().get_all_dag_bundles()) - manager._refresh_dag_bundles({}) - - import_errors = session.scalars(select(ParseImportError)).all() - assert len(import_errors) == 1 - assert import_errors[0].filename == "test_zip.zip/broken_dag.py" - - def test_refresh_dag_bundles_calls_legacy_deactivate_deleted_dags_override( - self, tmp_path, configure_dag_bundles + def test_refresh_dag_bundles_reconciles_via_client( + self, tmp_path, configure_dag_bundles, mock_dag_processing_client ): bundle_path = tmp_path / "bundleone" bundle_path.mkdir() dag_path = bundle_path / "test_dag.py" dag_path.write_text("from airflow.sdk import DAG\n") - class BackwardCompatibleManager(DagFileProcessorManager): - seen_bundle_name: str | None = None - seen_present: set[DagFileInfo] | None = None - - def deactivate_deleted_dags(self, bundle_name: str, present: set[DagFileInfo]) -> None: - self.seen_bundle_name = bundle_name - self.seen_present = present - with configure_dag_bundles({"bundleone": bundle_path}): DagBundlesManager().sync_bundles_to_db() - manager = BackwardCompatibleManager(max_runs=1) + manager = DagFileProcessorManager(max_runs=1) manager._dag_bundles = list(DagBundlesManager().get_all_dag_bundles()) manager._refresh_dag_bundles({}) - assert manager.seen_bundle_name == "bundleone" - assert manager.seen_present == { - DagFileInfo( - bundle_name="bundleone", - rel_path=Path("test_dag.py"), - bundle_path=bundle_path, - ) - } + mock_dag_processing_client.reconcile.assert_called_once_with( + bundle_name="bundleone", + observed_filelocs={"test_dag.py"}, + ) @conf_vars({("core", "load_examples"): "False"}) def test_max_runs_when_no_files(self, tmp_path): @@ -847,53 +776,6 @@ def test_recently_modified_file_uses_versioned_stats_without_creating_duplicate_ assert 
known_file not in manager._file_stats assert versioned_file in manager._file_stats - def test_file_paths_in_queue_sorted_by_priority(self): - from airflow.models.dagbag import DagPriorityParsingRequest - - parsing_request = DagPriorityParsingRequest(relative_fileloc="file_1.py", bundle_name="dags-folder") - with create_session() as session: - session.add(parsing_request) - session.commit() - - file1 = DagFileInfo( - bundle_name="dags-folder", rel_path=Path("file_1.py"), bundle_path=TEST_DAGS_FOLDER - ) - file2 = DagFileInfo( - bundle_name="dags-folder", rel_path=Path("file_2.py"), bundle_path=TEST_DAGS_FOLDER - ) - - manager = DagFileProcessorManager(max_runs=1) - manager._dag_bundles = list(DagBundlesManager().get_all_dag_bundles()) - manager._file_queue = OrderedDict.fromkeys([file2, file1]) - manager._queue_requested_files_for_parsing() - assert manager._file_queue == OrderedDict.fromkeys([file1, file2]) - assert manager._force_refresh_bundles == {"dags-folder"} - with create_session() as session2: - parsing_request_after = session2.get(DagPriorityParsingRequest, parsing_request.id) - assert parsing_request_after is None - - def test_parsing_requests_only_bundles_being_parsed(self, testing_dag_bundle): - """Ensure the manager only handles parsing requests for bundles being parsed in this manager""" - from airflow.models.dagbag import DagPriorityParsingRequest - - with create_session() as session: - session.add(DagPriorityParsingRequest(relative_fileloc="file_1.py", bundle_name="dags-folder")) - session.add(DagPriorityParsingRequest(relative_fileloc="file_x.py", bundle_name="testing")) - session.commit() - - file1 = DagFileInfo( - bundle_name="dags-folder", rel_path=Path("file_1.py"), bundle_path=TEST_DAGS_FOLDER - ) - - manager = DagFileProcessorManager(max_runs=1) - manager._dag_bundles = list(DagBundlesManager().get_all_dag_bundles()) - manager._queue_requested_files_for_parsing() - assert manager._file_queue == OrderedDict.fromkeys([file1]) - with 
create_session() as session2: - parsing_request_after = session2.scalars(select(DagPriorityParsingRequest)).all() - assert len(parsing_request_after) == 1 - assert parsing_request_after[0].relative_fileloc == "file_x.py" - def test_queue_requested_files_for_parsing_uses_public_claim_hook(self): file1 = DagFileInfo( bundle_name="dags-folder", rel_path=Path("file_1.py"), bundle_path=TEST_DAGS_FOLDER @@ -933,10 +815,11 @@ def test_request_bundle_refresh_accepts_single_bundle_name(self): assert manager._force_refresh_bundles == {"bundleone"} @pytest.mark.usefixtures("testing_dag_bundle") - def test_scan_stale_dags(self, session): + def test_scan_stale_dags(self, mock_dag_processing_client): """ - Ensure that DAGs are marked inactive when the file is parsed but the - DagModel.last_parsed_time is not updated. + Ensure ``_scan_stale_dags`` forwards the parsed files to the DAG Processing API so the + server can deactivate DAGs whose files are no longer present. The actual stale-marking + now happens server-side. 
""" manager = DagFileProcessorManager( max_runs=1, @@ -951,21 +834,13 @@ def test_scan_stale_dags(self, session): rel_path=Path("test_example_bash_operator.py"), bundle_path=TEST_DAGS_FOLDER, ) - dagbag = DagBag( - test_dag_path.absolute_path, - include_examples=False, - bundle_path=test_dag_path.bundle_path, - ) - - # Add stale DAG to the DB - dag = dagbag.get_dag("test_example_bash_operator") - sync_dag_to_db(dag, session=session) # Add DAG to the file_parsing_stats + finish_time = timezone.utcnow() + timedelta(hours=1) stat = DagFileStat( num_dags=1, import_errors=0, - last_finish_time=timezone.utcnow() + timedelta(hours=1), + last_finish_time=finish_time, last_duration=1, run_count=1, last_num_of_db_queries=1, @@ -973,72 +848,19 @@ def test_scan_stale_dags(self, session): manager._files = [test_dag_path] manager._file_stats[test_dag_path] = stat - active_dag_count = session.scalar( - select(func.count(DagModel.dag_id)).where( - ~DagModel.is_stale, - DagModel.relative_fileloc == str(test_dag_path.rel_path), - DagModel.bundle_name == test_dag_path.bundle_name, - ) - ) - assert active_dag_count == 1 - + # Force the cleanup interval so the scan actually runs this cycle. 
+ manager._last_deactivate_stale_dags_time = time.monotonic() - (manager.parsing_cleanup_interval + 1) manager._scan_stale_dags() - active_dag_count = session.scalar( - select(func.count(DagModel.dag_id)).where( - ~DagModel.is_stale, - DagModel.relative_fileloc == str(test_dag_path.rel_path), - DagModel.bundle_name == test_dag_path.bundle_name, - ) - ) - assert active_dag_count == 0 - - serialized_dag_count = session.scalar( - select(func.count(SerializedDagModel.dag_id)).where(SerializedDagModel.dag_id == dag.dag_id) - ) - # Deactivating the DagModel should not delete the SerializedDagModel - # SerializedDagModel gives history about Dags - assert serialized_dag_count == 1 - - @pytest.mark.usefixtures("testing_dag_bundle") - def test_deactivate_stale_dags_marks_dags_in_inactive_bundles(self, session): - """Dags whose bundle is no longer active should be marked stale even without a parse signal.""" - session.add(DagBundleModel(name="gone-bundle")) - session.flush() - session.execute( - DagBundleModel.__table__.update().where(DagBundleModel.name == "gone-bundle").values(active=False) - ) - session.add( - DagModel( - dag_id="dag_in_inactive_bundle", - bundle_name="gone-bundle", - relative_fileloc="some_file.py", - last_parsed_time=timezone.utcnow(), - is_stale=False, - ) - ) - session.add( - DagModel( - dag_id="dag_in_active_bundle", - bundle_name="testing", - relative_fileloc="other_file.py", - last_parsed_time=timezone.utcnow(), - is_stale=False, - ) - ) - session.flush() - - manager = DagFileProcessorManager(max_runs=1, processor_timeout=10 * 60) - manager.deactivate_stale_dags(last_parsed={}) - - is_stale_by_dag = dict( - session.execute( - select(DagModel.dag_id, DagModel.is_stale).where( - DagModel.dag_id.in_(["dag_in_inactive_bundle", "dag_in_active_bundle"]) - ) - ).all() - ) - assert is_stale_by_dag == {"dag_in_inactive_bundle": True, "dag_in_active_bundle": False} + mock_dag_processing_client.deactivate_stale_dags.assert_called_once() + call_kwargs = 
mock_dag_processing_client.deactivate_stale_dags.call_args.kwargs + assert call_kwargs["last_parsed"] == [ + { + "bundle_name": "testing", + "relative_fileloc": "test_example_bash_operator.py", + "last_finish_time": finish_time.isoformat(), + } + ] @mock.patch("airflow.dag_processing.manager.BundleUsageTrackingManager") def test_cleanup_stale_bundle_versions_interval(self, mock_bundle_manager): @@ -1119,24 +941,7 @@ def test_kill_timed_out_processors_no_kill(self): manager._kill_timed_out_processors() mock_kill.assert_not_called() - def test_handle_parsing_result_provides_its_own_session_when_caller_omits(self): - """``handle_parsing_result`` is wrapped in ``@provide_session`` so subclasses overriding it can run without a caller-supplied session.""" - manager = DagFileProcessorManager(max_runs=1) - file = DagFileInfo(bundle_name="testing", rel_path=Path("abc.txt"), bundle_path=TEST_DAGS_FOLDER) - manager._file_stats[file] = DagFileStat() - manager._bundle_versions["testing"] = "v1" - - processor, _ = self.mock_processor(start_time=time.monotonic() - 1) - processor.had_callbacks = False - processor.parsing_result = DagFileParsingResult(fileloc="abc.txt", serialized_dags=[]) - - with mock.patch.object(manager, "persist_parsing_result") as mock_persist: - manager.handle_parsing_result(file, processor) - - mock_persist.assert_called_once() - assert mock_persist.call_args.kwargs["session"] is not None - - def test_handle_parsing_result_throttles_retry_when_first_persist_fails(self, session): + def test_handle_parsing_result_throttles_retry_when_first_persist_fails(self): """Persist errors should throttle retries without claiming persistence succeeded.""" manager = DagFileProcessorManager(max_runs=1) file = DagFileInfo(bundle_name="testing", rel_path=Path("abc.txt"), bundle_path=TEST_DAGS_FOLDER) @@ -1150,7 +955,7 @@ def test_handle_parsing_result_throttles_retry_when_first_persist_fails(self, se processor.parsing_result = DagFileParsingResult(fileloc="abc.txt", 
serialized_dags=[]) with mock.patch.object(manager, "persist_parsing_result", side_effect=RuntimeError("boom")): - manager.handle_parsing_result(file, processor, session=session) + manager.handle_parsing_result(file, processor) assert manager._file_stats[file] is not original_stat assert manager._file_stats[file].num_dags == 0 @@ -1160,7 +965,7 @@ def test_handle_parsing_result_throttles_retry_when_first_persist_fails(self, se assert manager._file_stats[file].last_duration is not None assert manager.processed_recently(timezone.utcnow(), file) is True - def test_handle_parsing_result_updates_stats_after_successful_persist(self, session): + def test_handle_parsing_result_updates_stats_after_successful_persist(self): manager = DagFileProcessorManager(max_runs=1) file = DagFileInfo(bundle_name="testing", rel_path=Path("abc.txt"), bundle_path=TEST_DAGS_FOLDER) original_stat = DagFileStat( @@ -1179,7 +984,7 @@ def test_handle_parsing_result_updates_stats_after_successful_persist(self, sess processor.parsing_result = DagFileParsingResult(fileloc="abc.txt", serialized_dags=[]) with mock.patch.object(manager, "persist_parsing_result") as mock_persist: - manager.handle_parsing_result(file, processor, session=session) + manager.handle_parsing_result(file, processor) mock_persist.assert_called_once_with( bundle_name="testing", @@ -1188,7 +993,6 @@ def test_handle_parsing_result_updates_stats_after_successful_persist(self, sess parsing_result=processor.parsing_result, run_duration=mock.ANY, relative_fileloc="abc.txt", - session=session, ) assert manager._file_stats[file] is not original_stat assert manager._file_stats[file].run_count == 4 @@ -1311,23 +1115,15 @@ def test_dag_with_system_exit(self, configure_testing_dag_bundle): """ Test to check that a DAG with a system.exit() doesn't break the scheduler. 
""" - dag_id = "exit_test_dag" dag_directory = TEST_DAG_FOLDER.parent / "dags_with_system_exit" - # Delete the one valid DAG/SerializedDAG, and check that it gets re-created - clear_db_dags() - clear_db_serialized_dags() - with configure_testing_dag_bundle(dag_directory): manager = DagFileProcessorManager(max_runs=1) manager.run() - # Three files in folder should be processed + # The system-exit DAG must not abort the loop: all three files still get processed. assert sum(stat.run_count for stat in manager._file_stats.values()) == 3 - with create_session() as session: - assert session.get(DagModel, dag_id) is not None - @conf_vars({("core", "load_examples"): "False"}) @mock.patch("airflow.dag_processing.manager.stats.timing") def test_send_file_processing_statsd_timing( @@ -1361,84 +1157,6 @@ def test_send_file_processing_statsd_timing( ) @pytest.mark.usefixtures("testing_dag_bundle") - def test_refresh_dags_dir_doesnt_delete_zipped_dags( - self, tmp_path, session, configure_testing_dag_bundle, test_zip_path - ): - """Test DagFileProcessorManager._refresh_dag_dir method""" - dagbag = DagBag(dag_folder=tmp_path, include_examples=False) - dagbag.process_file(test_zip_path) - dag = dagbag.get_dag("test_zip_dag") - sync_dag_to_db(dag) - - with configure_testing_dag_bundle(test_zip_path): - manager = DagFileProcessorManager(max_runs=1) - manager.run() - - # Assert dag not deleted in SDM - assert SerializedDagModel.has_dag("test_zip_dag") - # assert code not deleted - assert DagCode.has_dag(dag.dag_id) - # assert dag still active - assert session.get(DagModel, dag.dag_id).is_stale is False - - @pytest.mark.usefixtures("testing_dag_bundle") - def test_refresh_dags_dir_deactivates_deleted_zipped_dags( - self, session, tmp_path, configure_testing_dag_bundle, test_zip_path - ): - """Test DagFileProcessorManager._refresh_dag_dir method""" - dag_id = "test_zip_dag" - filename = "test_zip.zip" - source_location = test_zip_path - bundle_path = Path(tmp_path, 
"test_refresh_dags_dir_deactivates_deleted_zipped_dags") - bundle_path.mkdir(exist_ok=True) - zip_dag_path = bundle_path / filename - shutil.copy(source_location, zip_dag_path) - - with configure_testing_dag_bundle(bundle_path): - session.commit() - manager = DagFileProcessorManager(max_runs=1) - manager.run() - - assert SerializedDagModel.has_dag(dag_id) - assert DagCode.has_dag(dag_id) - assert DagVersion.get_latest_version(dag_id) - dag = session.scalar(select(DagModel).where(DagModel.dag_id == dag_id)) - assert dag.is_stale is False - - os.remove(zip_dag_path) - - manager.run() - - assert SerializedDagModel.has_dag(dag_id) - assert DagCode.has_dag(dag_id) - assert DagVersion.get_latest_version(dag_id) - dag = session.scalar(select(DagModel).where(DagModel.dag_id == dag_id)) - assert dag.is_stale is True - - def test_deactivate_deleted_dags(self, dag_maker, session): - with dag_maker("test_dag1") as dag1: - dag1.relative_fileloc = "test_dag1.py" - with dag_maker("test_dag2") as dag2: - dag2.relative_fileloc = "test_dag2.py" - dag_maker.sync_dagbag_to_db() - - active_files = [ - DagFileInfo( - bundle_name="dag_maker", - rel_path=Path("test_dag1.py"), - bundle_path=TEST_DAGS_FOLDER, - ), - # Mimic that the test_dag2.py file is deleted - ] - - manager = DagFileProcessorManager(max_runs=1) - manager.deactivate_deleted_dags("dag_maker", set(active_files)) - - # The DAG from test_dag1.py is still active - assert session.get(DagModel, "test_dag1").is_stale is False - # and the DAG from test_dag2.py is deactivated - assert session.get(DagModel, "test_dag2").is_stale is True - @pytest.mark.parametrize( ("rel_filelocs", "expected_return", "expected_dag1_stale", "expected_dag2_stale"), [ @@ -1478,231 +1196,6 @@ def test_deactivate_deleted_dags_return_value( assert session.get(DagModel, "test_dag1").is_stale is expected_dag1_stale assert session.get(DagModel, "test_dag2").is_stale is expected_dag2_stale - @pytest.mark.parametrize( - ("active_files", "should_call_cleanup"), 
- [ - pytest.param( - [ - DagFileInfo( - bundle_name="dag_maker", - rel_path=Path("test_dag1.py"), - bundle_path=TEST_DAGS_FOLDER, - ), - # test_dag2.py is deleted - ], - True, # Should call cleanup - id="dags_deactivated", - ), - pytest.param( - [ - DagFileInfo( - bundle_name="dag_maker", - rel_path=Path("test_dag1.py"), - bundle_path=TEST_DAGS_FOLDER, - ), - DagFileInfo( - bundle_name="dag_maker", - rel_path=Path("test_dag2.py"), - bundle_path=TEST_DAGS_FOLDER, - ), - ], - False, # Should NOT call cleanup - id="no_dags_deactivated", - ), - ], - ) - @mock.patch("airflow.dag_processing.manager.remove_references_to_deleted_dags") - def test_manager_deactivate_deleted_dags_cleanup_behavior( - self, mock_remove_references, dag_maker, session, active_files, should_call_cleanup - ): - """Test that manager conditionally calls remove_references_to_deleted_dags based on whether DAGs were deactivated.""" - with dag_maker("test_dag1") as dag1: - dag1.relative_fileloc = "test_dag1.py" - with dag_maker("test_dag2") as dag2: - dag2.relative_fileloc = "test_dag2.py" - dag_maker.sync_dagbag_to_db() - - manager = DagFileProcessorManager(max_runs=1) - manager.deactivate_deleted_dags("dag_maker", set(active_files)) - - if should_call_cleanup: - mock_remove_references.assert_called_once() - else: - mock_remove_references.assert_not_called() - - @conf_vars({("core", "load_examples"): "False"}) - def test_fetch_callbacks_from_database(self, configure_testing_dag_bundle): - """Test _fetch_callbacks_from_db returns callbacks ordered by priority_weight desc.""" - - dag_filepath = TEST_DAG_FOLDER / "test_on_failure_callback_dag.py" - - callback1 = DagCallbackRequest( - dag_id="test_start_date_scheduling", - bundle_name="testing", - bundle_version=None, - filepath="test_on_failure_callback_dag.py", - is_failure_callback=True, - run_id="123", - ) - callback2 = DagCallbackRequest( - dag_id="test_start_date_scheduling", - bundle_name="testing", - bundle_version=None, - 
filepath="test_on_failure_callback_dag.py", - is_failure_callback=True, - run_id="456", - ) - - with create_session() as session: - session.add(DbCallbackRequest(callback=callback1, priority_weight=11)) - session.add(DbCallbackRequest(callback=callback2, priority_weight=10)) - - with configure_testing_dag_bundle(dag_filepath): - manager = DagFileProcessorManager(max_runs=1) - manager._dag_bundles = list(DagBundlesManager().get_all_dag_bundles()) - - with create_session() as session: - callbacks = manager._fetch_callbacks_from_db(session=session) - - # Should return callbacks ordered by priority_weight desc (highest first) - assert callbacks[0].run_id == "123" - assert callbacks[1].run_id == "456" - - assert len(session.scalars(select(DbCallbackRequest)).all()) == 0 - - @conf_vars( - { - ("dag_processor", "max_callbacks_per_loop"): "2", - ("core", "load_examples"): "False", - } - ) - def test_fetch_callbacks_from_database_max_per_loop(self, tmp_path, configure_testing_dag_bundle): - """Test DagFileProcessorManager.fetch_callbacks method.""" - dag_filepath = TEST_DAG_FOLDER / "test_on_failure_callback_dag.py" - - with create_session() as session: - for i in range(5): - callback = DagCallbackRequest( - dag_id="test_start_date_scheduling", - bundle_name="testing", - bundle_version=None, - filepath="test_on_failure_callback_dag.py", - is_failure_callback=True, - run_id=str(i), - ) - session.add(DbCallbackRequest(callback=callback, priority_weight=i)) - - with configure_testing_dag_bundle(dag_filepath): - manager = DagFileProcessorManager(max_runs=1) - - with create_session() as session: - manager.run() - assert len(session.scalars(select(DbCallbackRequest)).all()) == 3 - - with create_session() as session: - manager.run() - assert len(session.scalars(select(DbCallbackRequest)).all()) == 1 - - @conf_vars({("core", "load_examples"): "False"}) - def test_fetch_callbacks_ignores_other_bundles(self, configure_testing_dag_bundle): - """Ensure callbacks for bundles not owned 
by current dag processor manager are ignored and not deleted.""" - - dag_filepath = TEST_DAG_FOLDER / "test_on_failure_callback_dag.py" - - # Create two callbacks: one for the active 'testing' bundle and one for a different bundle - matching = DagCallbackRequest( - dag_id="test_start_date_scheduling", - bundle_name="testing", - bundle_version=None, - filepath="test_on_failure_callback_dag.py", - is_failure_callback=True, - run_id="match", - ) - non_matching = DagCallbackRequest( - dag_id="test_start_date_scheduling", - bundle_name="other-bundle", - bundle_version=None, - filepath="test_on_failure_callback_dag.py", - is_failure_callback=True, - run_id="no-match", - ) - - with create_session() as session: - session.add(DbCallbackRequest(callback=matching, priority_weight=100)) - session.add(DbCallbackRequest(callback=non_matching, priority_weight=200)) - - with configure_testing_dag_bundle(dag_filepath): - manager = DagFileProcessorManager(max_runs=1) - manager._dag_bundles = list(DagBundlesManager().get_all_dag_bundles()) - - with create_session() as session: - callbacks = manager._fetch_callbacks_from_db(session=session) - - # Only the matching callback should be returned - assert [c.run_id for c in callbacks] == ["match"] - - # The non-matching callback should remain in the DB - remaining = session.scalars(select(DbCallbackRequest)).all() - assert len(remaining) == 1 - # Decode remaining request and verify it's for the other bundle - remaining_req = remaining[0].get_callback_request() - assert remaining_req.bundle_name == "other-bundle" - - @conf_vars( - { - ("dag_processor", "max_callbacks_per_loop"): "2", - ("core", "load_examples"): "False", - } - ) - def test_fetch_callbacks_filters_by_bundle_before_limit(self, configure_testing_dag_bundle): - dag_filepath = TEST_DAG_FOLDER / "test_on_failure_callback_dag.py" - - matching = DagCallbackRequest( - dag_id="test_start_date_scheduling", - bundle_name="testing", - bundle_version=None, - 
filepath="test_on_failure_callback_dag.py", - is_failure_callback=True, - run_id="match", - ) - non_matching_1 = DagCallbackRequest( - dag_id="test_start_date_scheduling", - bundle_name="other-bundle-a", - bundle_version=None, - filepath="test_on_failure_callback_dag.py", - is_failure_callback=True, - run_id="no-match-1", - ) - non_matching_2 = DagCallbackRequest( - dag_id="test_start_date_scheduling", - bundle_name="other-bundle-b", - bundle_version=None, - filepath="test_on_failure_callback_dag.py", - is_failure_callback=True, - run_id="no-match-2", - ) - - with create_session() as session: - session.add(DbCallbackRequest(callback=non_matching_1, priority_weight=300)) - session.add(DbCallbackRequest(callback=non_matching_2, priority_weight=200)) - session.add(DbCallbackRequest(callback=matching, priority_weight=100)) - - with configure_testing_dag_bundle(dag_filepath): - manager = DagFileProcessorManager(max_runs=1) - manager._dag_bundles = list(DagBundlesManager().get_all_dag_bundles()) - - with create_session() as session: - callbacks = manager._fetch_callbacks_from_db(session=session) - - assert [c.run_id for c in callbacks] == ["match"] - - remaining = session.scalars(select(DbCallbackRequest)).all() - assert len(remaining) == 2 - assert {callback.bundle_name for callback in remaining} == { - "other-bundle-a", - "other-bundle-b", - } - @mock.patch.object(DagFileProcessorManager, "_get_logger_for_dag_file") def test_callback_queue(self, mock_get_logger, configure_testing_dag_bundle): mock_logger = MagicMock() @@ -1981,29 +1474,6 @@ def test_add_callback_skips_when_bundle_unconfigured(self, mock_bundle_manager): assert not manager._callback_to_execute - def test_fetch_callbacks_delegates_to_private_method(self): - manager = DagFileProcessorManager(max_runs=1) - expected: list = [mock.sentinel.callback] - with mock.patch.object(manager, "_fetch_callbacks_from_db", return_value=expected) as private: - assert manager.fetch_callbacks() is expected - 
private.assert_called_once_with() - - def test_dag_with_assets(self, session, configure_testing_dag_bundle): - """'Integration' test to ensure that the assets get parsed and stored correctly for parsed dags.""" - test_dag_path = str(TEST_DAG_FOLDER / "test_assets.py") - - with configure_testing_dag_bundle(test_dag_path): - manager = DagFileProcessorManager( - max_runs=1, - processor_timeout=365 * 86_400, - ) - manager.run() - - dag_model = session.get(DagModel, ("dag_with_skip_task")) - assert dag_model.task_outlet_asset_references == [ - TaskOutletAssetReference(asset_id=mock.ANY, dag_id="dag_with_skip_task", task_id="skip_task") - ] - def test_bundles_are_refreshed(self): """ Ensure bundles are refreshed by the manager, when necessary. @@ -2174,7 +1644,7 @@ def test_bundle_force_refresh(self): manager._refresh_dag_bundles({}) assert bundleone.refresh.call_count == 2 # forced refresh - def test_bundles_versions_are_stored(self, session): + def test_bundles_versions_are_stored(self, mock_dag_processing_client): config = [ { "name": "bundleone", @@ -2198,9 +1668,10 @@ def test_bundles_versions_are_stored(self, session): manager = DagFileProcessorManager(max_runs=1) manager.run() - with create_session() as session: - model = session.get(DagBundleModel, "bundleone") - assert model.version == "123" + # The refreshed bundle version is persisted via the DAG Processing API. 
+ mock_dag_processing_client.update_bundle_state.assert_any_call( + "bundleone", last_refreshed=mock.ANY, version="123" + ) def test_non_versioned_bundle_get_version_not_called(self): config = [ @@ -2408,15 +1879,6 @@ def test_after_run_runs_when_parsing_loop_raises(self, tmp_path, configure_testi manager.run() after_run_mock.assert_called_once_with() - def test_prepare_server_process_context_can_be_skipped(self, tmp_path, configure_testing_dag_bundle): - """API-backed subclasses can skip server-context setup without losing selector/stats init.""" - with configure_testing_dag_bundle(tmp_path): - manager = DagFileProcessorManager(max_runs=1) - with mock.patch.object(manager, "prepare_server_process_context") as server_ctx_mock: - manager.run() - server_ctx_mock.assert_called_once_with() - assert manager.selector is not None - def test_prepare_bundles_can_be_overridden_without_sync(self, tmp_path, configure_testing_dag_bundle): """Subclasses can override `prepare_bundles` to skip `sync_bundles` (AIP-92 seam).""" with configure_testing_dag_bundle(tmp_path): @@ -2429,28 +1891,13 @@ def test_prepare_bundles_can_be_overridden_without_sync(self, tmp_path, configur sync_mock.assert_not_called() assert [b.name for b in manager._dag_bundles] == ["testing"] - def test_purge_inactive_dag_warnings_delegates_to_dagwarning(self): - """Default `purge_inactive_dag_warnings` calls `DagWarning.purge_inactive_dag_warnings`.""" - manager = DagFileProcessorManager(max_runs=1) - with mock.patch( - "airflow.dag_processing.manager.DagWarning.purge_inactive_dag_warnings" - ) as purge_mock: - manager.purge_inactive_dag_warnings() - purge_mock.assert_called_once_with() - def test_run_parsing_loop_uses_overridable_purge(self, tmp_path, configure_testing_dag_bundle): """`_run_parsing_loop` calls the overridable `purge_inactive_dag_warnings` seam.""" with configure_testing_dag_bundle(tmp_path): manager = DagFileProcessorManager(max_runs=1) - with ( - mock.patch.object(manager, 
"purge_inactive_dag_warnings") as purge_mock, - mock.patch( - "airflow.dag_processing.manager.DagWarning.purge_inactive_dag_warnings" - ) as direct_mock, - ): + with mock.patch.object(manager, "purge_inactive_dag_warnings") as purge_mock: manager.run() purge_mock.assert_called() - direct_mock.assert_not_called() @mock.patch("airflow.dag_processing.manager.stats.gauge") def test_stats_total_parse_time(self, statsd_gauge_mock, tmp_path, configure_testing_dag_bundle): @@ -2481,77 +1928,6 @@ def test_stats_total_parse_time(self, statsd_gauge_mock, tmp_path, configure_tes dag_path.touch() # make the loop run faster gauge_values.clear() - # --- get_bundle_state / update_bundle_state --- - - def test_get_bundle_state_returns_none_for_missing_bundle(self): - manager = DagFileProcessorManager(max_runs=1) - assert manager.get_bundle_state("nonexistent_bundle") is None - - def test_get_bundle_state_returns_correct_state(self, session): - bundle_name = "test_state_bundle" - refreshed_at = timezone.datetime(2024, 1, 15, 12, 0, 0) - model = DagBundleModel(name=bundle_name, version="v1") - model.last_refreshed = refreshed_at - session.add(model) - session.commit() - - manager = DagFileProcessorManager(max_runs=1) - state = manager.get_bundle_state(bundle_name) - - assert state == BundleState(last_refreshed=refreshed_at, version="v1") - - def test_get_bundle_state_null_fields(self, session): - bundle_name = "test_null_state_bundle" - session.add(DagBundleModel(name=bundle_name)) - session.commit() - - manager = DagFileProcessorManager(max_runs=1) - state = manager.get_bundle_state(bundle_name) - - assert state == BundleState(last_refreshed=None, version=None) - - def test_update_bundle_state_sets_last_refreshed(self, session): - bundle_name = "test_update_bundle" - session.add(DagBundleModel(name=bundle_name)) - session.commit() - - refreshed_at = timezone.datetime(2024, 6, 1, 8, 0, 0) - manager = DagFileProcessorManager(max_runs=1) - manager.update_bundle_state(bundle_name, 
last_refreshed=refreshed_at, version=None) - - session.expire_all() - model = session.get(DagBundleModel, bundle_name) - assert model.last_refreshed == refreshed_at - assert model.version is None - - def test_update_bundle_state_sets_version(self, session): - bundle_name = "test_update_version_bundle" - session.add(DagBundleModel(name=bundle_name)) - session.commit() - - refreshed_at = timezone.datetime(2024, 6, 1, 8, 0, 0) - manager = DagFileProcessorManager(max_runs=1) - manager.update_bundle_state(bundle_name, last_refreshed=refreshed_at, version="abc123") - - session.expire_all() - model = session.get(DagBundleModel, bundle_name) - assert model.last_refreshed == refreshed_at - assert model.version == "abc123" - - def test_update_bundle_state_does_not_overwrite_version_when_none(self, session): - bundle_name = "test_preserve_version_bundle" - session.add(DagBundleModel(name=bundle_name, version="keep_me")) - session.commit() - - refreshed_at = timezone.datetime(2024, 6, 1, 8, 0, 0) - manager = DagFileProcessorManager(max_runs=1) - manager.update_bundle_state(bundle_name, last_refreshed=refreshed_at, version=None) - - session.expire_all() - model = session.get(DagBundleModel, bundle_name) - assert model.last_refreshed == refreshed_at - assert model.version == "keep_me" - def _make_refresh_bundle(self, *, supports_versioning=False, current_version=None): bundle = MagicMock(spec=BaseDagBundle) bundle.name = "mock_bundle" @@ -2572,14 +1948,13 @@ def _refresh_with_mocked_state(self, manager, bundle, initial_state): """ manager._dag_bundles = [bundle] manager._force_refresh_bundles = set() + manager._dag_processing_client = mock.MagicMock() mock_get = mock.patch.object(manager, "get_bundle_state", return_value=initial_state) mock_update = mock.patch.object(manager, "update_bundle_state") with ( mock_get as patched_get, mock_update as patched_update, mock.patch.object(manager, "_find_files_in_bundle", return_value=[]), - mock.patch.object(manager, 
"deactivate_deleted_dags"), - mock.patch.object(manager, "clear_orphaned_import_errors"), mock.patch.object(manager, "handle_removed_files"), mock.patch.object(manager, "_resort_file_queue"), mock.patch.object(manager, "_add_new_files_to_queue"), @@ -2640,6 +2015,7 @@ def test_refresh_dag_bundles_versioned_version_unchanged_persist_failure(self): manager._bundle_versions["mock_bundle"] = "v1" manager._dag_bundles = [bundle] manager._force_refresh_bundles = set() + manager._dag_processing_client = mock.MagicMock() known_files: dict[str, set[DagFileInfo]] = {} with ( @@ -2648,8 +2024,6 @@ def test_refresh_dag_bundles_versioned_version_unchanged_persist_failure(self): ), mock.patch.object(manager, "update_bundle_state", side_effect=Exception("DB error")), mock.patch.object(manager, "_find_files_in_bundle", return_value=[]) as mock_find, - mock.patch.object(manager, "deactivate_deleted_dags"), - mock.patch.object(manager, "clear_orphaned_import_errors"), mock.patch.object(manager, "handle_removed_files"), mock.patch.object(manager, "_resort_file_queue"), mock.patch.object(manager, "_add_new_files_to_queue"), @@ -2703,6 +2077,7 @@ def test_refresh_dag_bundles_update_bundle_state_failure_still_scans_files(self) manager = DagFileProcessorManager(max_runs=1) bundle = self._make_refresh_bundle() manager._dag_bundles = [bundle] + manager._dag_processing_client = mock.MagicMock() known_files: dict[str, set[DagFileInfo]] = {} with ( @@ -2711,8 +2086,6 @@ def test_refresh_dag_bundles_update_bundle_state_failure_still_scans_files(self) ), mock.patch.object(manager, "update_bundle_state", side_effect=Exception("API error")), mock.patch.object(manager, "_find_files_in_bundle", return_value=[]), - mock.patch.object(manager, "deactivate_deleted_dags"), - mock.patch.object(manager, "clear_orphaned_import_errors"), mock.patch.object(manager, "handle_removed_files"), mock.patch.object(manager, "_resort_file_queue"), mock.patch.object(manager, "_add_new_files_to_queue"), diff --git 
a/airflow-core/tests/unit/dag_processing/test_manager_api_persistence.py b/airflow-core/tests/unit/dag_processing/test_manager_api_persistence.py new file mode 100644 index 0000000000000..02c5eaab1e797 --- /dev/null +++ b/airflow-core/tests/unit/dag_processing/test_manager_api_persistence.py @@ -0,0 +1,317 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Tests for the API-backed persistence of DagFileProcessorManager (AIP-92). + +The DAG processor never reads or writes the metadata database directly; every persistence and +metadata operation is routed through the DAG Processing API client, and parse-time metadata +reads go through the remote Execution API. 
+""" + +from __future__ import annotations + +from datetime import datetime, timezone +from pathlib import Path +from unittest import mock + +from airflow.dag_processing.manager import ( + BundleState, + DagFileInfo, + DagFileProcessorManager, + _api_token, + _dag_processing_api_server_url, +) + +from tests_common.test_utils.config import conf_vars + +MGR = "airflow.dag_processing.manager" + + +# --- DAG Processing API URL resolution + client construction --- + + +@conf_vars({("api", "base_url"): "http://my-host:9999/"}) +def test_dag_processing_api_server_url_derives_sibling_mount(): + assert _dag_processing_api_server_url() == "http://my-host:9999/dag-processing" + + +@conf_vars({("core", "dag_processing_api_server_url"): "http://explicit/dag-processing"}) +def test_dag_processing_api_server_url_prefers_explicit(): + assert _dag_processing_api_server_url() == "http://explicit/dag-processing" + + +def test_api_token_reads_externally_provisioned_file(tmp_path): + # The processor only carries a token a trusted component wrote; it never mints one. + token_file = tmp_path / "token" + token_file.write_text("provisioned-token\n") + with conf_vars({("dag_processor", "api_token_path"): str(token_file)}): + assert _api_token() == "provisioned-token" + + +def test_api_token_none_when_unset(): + with conf_vars({("dag_processor", "api_token_path"): ""}): + assert _api_token() is None + + +@conf_vars({("core", "load_examples"): "False"}) +@mock.patch(f"{MGR}.DagProcessingApiClient") +def test_client_built_with_derived_default(mock_client_cls): + # No explicit URL -> always built, pointing at the derived /dag-processing mount. 
+ manager = DagFileProcessorManager(max_runs=1) + assert manager._dag_processing_client is mock_client_cls.return_value + (url,) = mock_client_cls.call_args.args + assert url.endswith("/dag-processing") + + +@conf_vars( + { + ("core", "load_examples"): "False", + ("core", "dag_processing_api_server_url"): "http://api:8080/dag-processing", + } +) +@mock.patch(f"{MGR}.DagProcessingApiClient") +def test_client_built_with_explicit_url(mock_client_cls): + manager = DagFileProcessorManager(max_runs=1) + assert manager._dag_processing_client is mock_client_cls.return_value + # The client is built with the explicit URL (a token kwarg is also passed). + assert mock_client_cls.call_args.args[0] == "http://api:8080/dag-processing" + + +# --- every persistence/metadata method delegates to the client --- + + +@conf_vars({("core", "load_examples"): "False"}) +@mock.patch(f"{MGR}.DagProcessingApiClient") +def test_persist_delegates_to_client(mock_client_cls): + manager = DagFileProcessorManager(max_runs=1) + client = manager._dag_processing_client + + manager.persist_parsing_result( + bundle_name="b1", + bundle_version="v1", + version_data=None, + parsing_result=mock.MagicMock(), + run_duration=1.0, + relative_fileloc="dags/a.py", + ) + + client.persist_parsing_result.assert_called_once() + kwargs = client.persist_parsing_result.call_args.kwargs + assert kwargs["bundle_name"] == "b1" + assert kwargs["relative_fileloc"] == "dags/a.py" + + +@conf_vars({("core", "load_examples"): "False"}) +@mock.patch(f"{MGR}.DagProcessingApiClient") +def test_get_bundle_state_delegates_to_client(mock_client_cls): + manager = DagFileProcessorManager(max_runs=1) + client = manager._dag_processing_client + client.get_bundle_state.return_value = {"last_refreshed": None, "version": "v9"} + + state = manager.get_bundle_state("b1") + + client.get_bundle_state.assert_called_once_with("b1") + assert state == BundleState(last_refreshed=None, version="v9") + + +@conf_vars({("core", 
"load_examples"): "False"}) +@mock.patch(f"{MGR}.DagProcessingApiClient") +def test_get_bundle_state_returns_none_from_client(mock_client_cls): + manager = DagFileProcessorManager(max_runs=1) + client = manager._dag_processing_client + client.get_bundle_state.return_value = None + + assert manager.get_bundle_state("b1") is None + + +@conf_vars({("core", "load_examples"): "False"}) +@mock.patch(f"{MGR}.DagProcessingApiClient") +def test_update_bundle_state_delegates_to_client(mock_client_cls): + manager = DagFileProcessorManager(max_runs=1) + client = manager._dag_processing_client + ts = datetime.now(timezone.utc) + + manager.update_bundle_state("b1", last_refreshed=ts, version=None) + + client.update_bundle_state.assert_called_once_with("b1", last_refreshed=ts, version=None) + + +@conf_vars({("core", "load_examples"): "False"}) +@mock.patch(f"{MGR}.DagProcessingApiClient") +def test_sync_bundles_delegates_to_client(mock_client_cls): + manager = DagFileProcessorManager(max_runs=1) + client = manager._dag_processing_client + + manager.sync_bundles() + + client.sync_bundles.assert_called_once_with() + + +@conf_vars({("core", "load_examples"): "False"}) +@mock.patch(f"{MGR}.DagProcessingApiClient") +def test_deactivate_stale_dags_delegates_to_client(mock_client_cls): + manager = DagFileProcessorManager(max_runs=1) + client = manager._dag_processing_client + file_info = DagFileInfo(rel_path=Path("a.py"), bundle_name="b1") + + manager.deactivate_stale_dags(last_parsed={file_info: datetime.now(timezone.utc)}) + + client.deactivate_stale_dags.assert_called_once() + entries = client.deactivate_stale_dags.call_args.kwargs["last_parsed"] + assert entries[0]["bundle_name"] == "b1" + assert entries[0]["relative_fileloc"] == "a.py" + + +@conf_vars({("core", "load_examples"): "False"}) +@mock.patch(f"{MGR}.DagProcessingApiClient") +def test_purge_inactive_dag_warnings_delegates_to_client(mock_client_cls): + manager = DagFileProcessorManager(max_runs=1) + client = 
manager._dag_processing_client + + manager.purge_inactive_dag_warnings() + + client.purge_inactive_dag_warnings.assert_called_once_with() + + +@conf_vars({("core", "load_examples"): "False"}) +@mock.patch(f"{MGR}.DagProcessingApiClient") +def test_claim_priority_files_delegates_to_client(mock_client_cls): + manager = DagFileProcessorManager(max_runs=1) + client = manager._dag_processing_client + client.claim_priority_files.return_value = [{"bundle_name": "b1", "relative_fileloc": "a.py"}] + bundle = mock.MagicMock() + bundle.name = "b1" + bundle.path = Path("/bundles/b1") + manager._dag_bundles = [bundle] + + files = manager.claim_priority_files() + + client.claim_priority_files.assert_called_once_with(["b1"]) + assert len(files) == 1 + assert files[0].bundle_name == "b1" + assert str(files[0].rel_path) == "a.py" + + +@conf_vars({("core", "load_examples"): "False"}) +@mock.patch(f"{MGR}.DagProcessingApiClient") +def test_fetch_callbacks_delegates_to_client(mock_client_cls): + manager = DagFileProcessorManager(max_runs=1) + client = manager._dag_processing_client + client.fetch_callbacks.return_value = ["callback-request"] + + result = manager.fetch_callbacks() + + client.fetch_callbacks.assert_called_once() + assert result == ["callback-request"] + + +# --- Job lifecycle runs through the API --- + + +def test_run_dag_processor_job_uses_api(): + from airflow.cli.commands.dag_processor_command import _run_dag_processor_job + + client = mock.MagicMock() + client.register_job.return_value = 99 + processor = mock.MagicMock() + processor._dag_processing_client = client + job_runner = mock.MagicMock(job_type="DagProcessorJob", processor=processor) + job_runner.job.heartrate = 0.0 # no throttle so the first heartbeat fires immediately + + _run_dag_processor_job(job_runner) + + client.register_job.assert_called_once_with("DagProcessorJob") + job_runner._execute.assert_called_once() + client.complete_job.assert_called_once_with(99, state="success") + # the heartbeat hook 
routes through the API client + processor.heartbeat() + client.job_heartbeat.assert_called_with(99) + + +# --- parse-time metadata reads go to the remote Execution API --- + + +@conf_vars( + { + ("core", "load_examples"): "False", + ("core", "execution_api_server_url"): "http://exec:8080/execution", + } +) +@mock.patch(f"{MGR}.DagProcessingApiClient") +def test_parse_time_client_is_remote(mock_client_cls): + manager = DagFileProcessorManager(max_runs=1) + with ( + mock.patch("airflow.dag_processing.manager._api_token", return_value="tok"), + mock.patch("airflow.dag_processing.manager.Client") as client_cls, + ): + _ = manager.client + + client_cls.assert_called_once() + kwargs = client_cls.call_args.kwargs + assert kwargs["base_url"] == "http://exec:8080/execution" + assert kwargs["token"] == "tok" + assert kwargs["dry_run"] is False + + +# --- bundle-init credentials resolve through the Execution API (no metadata DB) --- + + +def test_secrets_comms_resolves_connection_via_execution_api(): + from airflow.dag_processing.manager import _DagProcessorSecretsComms + from airflow.sdk.execution_time.comms import GetConnection + + client = mock.MagicMock() + comms = _DagProcessorSecretsComms(client) + with mock.patch( + "airflow.dag_processing.manager.handle_get_connection", + return_value=("conn-result", {}), + ) as handler: + result = comms.send(GetConnection(conn_id="git_default")) + + handler.assert_called_once() + assert handler.call_args.args[0] is client # routed through the Execution API client + assert result == "conn-result" + + +def test_secrets_comms_ignores_unsupported_messages(): + from airflow.dag_processing.manager import _DagProcessorSecretsComms + + comms = _DagProcessorSecretsComms(mock.MagicMock()) + # e.g. a MaskSecret emitted while masking a fetched secret: no-op, never touches the client. 
+ assert comms.send(object()) is None + + +@conf_vars({("core", "load_examples"): "False"}) +@mock.patch(f"{MGR}.DagProcessingApiClient") +def test_before_run_installs_secrets_comms(mock_client_cls): + from airflow.dag_processing.manager import _DagProcessorSecretsComms + from airflow.sdk.execution_time import task_runner + + manager = DagFileProcessorManager(max_runs=1) + manager.__dict__["client"] = mock.MagicMock() # pre-populate the cached_property + + sentinel = object() + old = getattr(task_runner, "SUPERVISOR_COMMS", sentinel) + try: + manager._setup_secrets_comms() + assert isinstance(task_runner.SUPERVISOR_COMMS, _DagProcessorSecretsComms) + finally: + if old is sentinel: + if hasattr(task_runner, "SUPERVISOR_COMMS"): + del task_runner.SUPERVISOR_COMMS + else: + task_runner.SUPERVISOR_COMMS = old diff --git a/airflow-core/tests/unit/dag_processing/test_no_db_mode.py b/airflow-core/tests/unit/dag_processing/test_no_db_mode.py new file mode 100644 index 0000000000000..7712d446fa3bb --- /dev/null +++ b/airflow-core/tests/unit/dag_processing/test_no_db_mode.py @@ -0,0 +1,128 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""No-DB-mode harness for the DAG processor (AIP-92). 
+ +The goal is a DAG processor that runs with zero metadata-DB access (all DB I/O +forwarded over the API). `forbid_db_access` makes any metadata-DB connection +raise, so tests can assert a code path is DB-free. Each test below covers one DB +touchpoint that has moved behind the API and asserts it opens no DB connection. +""" + +from __future__ import annotations + +import contextlib +from datetime import datetime, timezone +from unittest import mock + +import pytest + +from airflow.dag_processing.manager import DagFileProcessorManager + +from tests_common.test_utils.config import conf_vars + +MGR = "airflow.dag_processing.manager" +API_URL = "http://localhost:8080/dag-processing" + +# These tests patch ``settings.engine`` to forbid connections, so they need a real engine to +# patch. In the Non-DB test environment ``settings.engine`` is None, so run them as DB tests. +pytestmark = pytest.mark.db_test + + +@contextlib.contextmanager +def forbid_db_access(): + """Raise ``AssertionError`` if any metadata-DB connection is opened in the block.""" + from airflow import settings + + def _forbid(*args, **kwargs): + raise AssertionError("Metadata DB access is forbidden in no-DB mode") + + with ( + mock.patch.object(settings.engine, "connect", side_effect=_forbid), + mock.patch.object(settings.engine, "raw_connection", side_effect=_forbid), + ): + yield + + +@conf_vars({("core", "load_examples"): "False"}) +def test_forbid_db_access_blocks_real_queries(): + """The guard itself works: a real query under it raises.""" + from sqlalchemy import text + + from airflow.utils.session import create_session + + with forbid_db_access(), pytest.raises(AssertionError, match="forbidden"): + with create_session() as session: + session.execute(text("SELECT 1")).all() + + +@conf_vars({("core", "load_examples"): "False", ("core", "dag_processing_api_server_url"): API_URL}) +@mock.patch(f"{MGR}.DagProcessingApiClient") +def test_bundle_state_is_db_free_when_api_configured(mock_client_cls): +
"""Reading and updating bundle state go through the API, opening no DB connection.""" + client = mock_client_cls.return_value + client.get_bundle_state.return_value = None + manager = DagFileProcessorManager(max_runs=1) + with forbid_db_access(): + assert manager.get_bundle_state("any-bundle") is None + manager.update_bundle_state("any-bundle", last_refreshed=datetime.now(timezone.utc), version="v1") + client.get_bundle_state.assert_called_once_with("any-bundle") + client.update_bundle_state.assert_called_once() + + +@conf_vars({("core", "load_examples"): "False", ("core", "dag_processing_api_server_url"): API_URL}) +@mock.patch(f"{MGR}.DagProcessingApiClient") +def test_sync_bundles_is_db_free_when_api_configured(mock_client_cls): + """Startup bundle sync goes through the API, opening no DB connection.""" + manager = DagFileProcessorManager(max_runs=1) + with forbid_db_access(): + manager.sync_bundles() + mock_client_cls.return_value.sync_bundles.assert_called_once_with() + + +@conf_vars({("core", "load_examples"): "False", ("core", "dag_processing_api_server_url"): API_URL}) +@mock.patch(f"{MGR}.DagProcessingApiClient") +def test_stale_sweep_is_db_free_when_api_configured(mock_client_cls): + """Time-based stale-dag deactivation and warning purge go through the API, no DB.""" + manager = DagFileProcessorManager(max_runs=1) + with forbid_db_access(): + manager.deactivate_stale_dags(last_parsed={}) + manager.purge_inactive_dag_warnings() + client = mock_client_cls.return_value + client.deactivate_stale_dags.assert_called_once() + client.purge_inactive_dag_warnings.assert_called_once_with() + + +@conf_vars({("core", "load_examples"): "False", ("core", "dag_processing_api_server_url"): API_URL}) +@mock.patch(f"{MGR}.DagProcessingApiClient") +def test_priority_claim_is_db_free_when_api_configured(mock_client_cls): + """Claiming priority parse requests goes through the API, opening no DB connection.""" + mock_client_cls.return_value.claim_priority_files.return_value = 
[] + manager = DagFileProcessorManager(max_runs=1) + with forbid_db_access(): + assert manager.claim_priority_files() == [] + mock_client_cls.return_value.claim_priority_files.assert_called_once() + + +@conf_vars({("core", "load_examples"): "False", ("core", "dag_processing_api_server_url"): API_URL}) +@mock.patch(f"{MGR}.DagProcessingApiClient") +def test_callback_claim_is_db_free_when_api_configured(mock_client_cls): + """Fetching callbacks goes through the API, opening no DB connection.""" + mock_client_cls.return_value.fetch_callbacks.return_value = [] + manager = DagFileProcessorManager(max_runs=1) + with forbid_db_access(): + assert manager.fetch_callbacks() == [] + mock_client_cls.return_value.fetch_callbacks.assert_called_once() diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py index ef92372f57378..4069299ea5004 100644 --- a/airflow-core/tests/unit/dag_processing/test_processor.py +++ b/airflow-core/tests/unit/dag_processing/test_processor.py @@ -50,7 +50,7 @@ TaskCallbackRequest, ) from airflow.dag_processing.dagbag import DagBag -from airflow.dag_processing.manager import DagFileProcessorManager, process_parse_results +from airflow.dag_processing.manager import process_parse_results from airflow.dag_processing.processor import ( DagFileParseRequest, DagFileParsingResult, @@ -706,43 +706,6 @@ def test_import_error_updates_timestamps(): assert stat.import_errors == 1 -def test_persist_parsing_result_calls_update_db(): - """persist_parsing_result should delegate to update_dag_parsing_results_in_db with transformed args.""" - parsing_result = DagFileParsingResult( - fileloc="test.py", - serialized_dags=[], - import_errors={"dags/broken.py": "SyntaxError"}, - warnings=[], - ) - - manager = MagicMock(spec=DagFileProcessorManager) - session = MagicMock() - # Call the real method on the mock instance - with patch( - "airflow.dag_processing.manager.update_dag_parsing_results_in_db", 
autospec=True - ) as mock_update: - DagFileProcessorManager.persist_parsing_result( - manager, - bundle_name="test-bundle", - bundle_version="v1", - version_data=None, - parsing_result=parsing_result, - run_duration=1.5, - relative_fileloc="dags/test.py", - session=session, - ) - - mock_update.assert_called_once() - call_kwargs = mock_update.call_args.kwargs - assert call_kwargs["bundle_name"] == "test-bundle" - assert call_kwargs["bundle_version"] == "v1" - assert call_kwargs["parse_duration"] == 1.5 - assert call_kwargs["session"] is session - assert call_kwargs["import_errors"] == {("test-bundle", "dags/broken.py"): "SyntaxError"} - assert ("test-bundle", "dags/test.py") in call_kwargs["files_parsed"] - assert ("test-bundle", "dags/broken.py") in call_kwargs["files_parsed"] - - class TestExecuteCallbacks: def test_execute_callbacks_locks_bundle_version(self): callbacks = [ From 71025e4b40a9d8e98556427de127704152b457b7 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Wed, 3 Jun 2026 00:38:44 +0100 Subject: [PATCH 2/7] Add the DAG Processing API server app (AIP-92) Mount a /dag-processing FastAPI app on the API server with the endpoints the DAG processor persists through: parsing-results, bundle reconcile/state/sync, stale-dags, purge-warnings, priority-parse and callback claim, and Job register/heartbeat/complete. Split into app.py (routes), datamodels.py, and security.py, matching the execution_api layout. The endpoints validate the bearer token the processor presents via JWTBearer, using the same get_sig_validation_args path as the Execution API, so [api_auth] trusted_jwks_url applies equally. /health is open for readiness probes. 
--- airflow-core/src/airflow/api_fastapi/app.py | 6 +- .../api_fastapi/dag_processing/__init__.py | 25 ++ .../airflow/api_fastapi/dag_processing/app.py | 314 +++++++++++++++ .../api_fastapi/dag_processing/datamodels.py | 97 +++++ .../api_fastapi/dag_processing/security.py | 73 ++++ .../api_fastapi/dag_processing/__init__.py | 17 + .../api_fastapi/dag_processing/test_app.py | 366 ++++++++++++++++++ .../api_fastapi/dag_processing/test_app_db.py | 291 ++++++++++++++ 8 files changed, 1188 insertions(+), 1 deletion(-) create mode 100644 airflow-core/src/airflow/api_fastapi/dag_processing/__init__.py create mode 100644 airflow-core/src/airflow/api_fastapi/dag_processing/app.py create mode 100644 airflow-core/src/airflow/api_fastapi/dag_processing/datamodels.py create mode 100644 airflow-core/src/airflow/api_fastapi/dag_processing/security.py create mode 100644 airflow-core/tests/unit/api_fastapi/dag_processing/__init__.py create mode 100644 airflow-core/tests/unit/api_fastapi/dag_processing/test_app.py create mode 100644 airflow-core/tests/unit/api_fastapi/dag_processing/test_app_db.py diff --git a/airflow-core/src/airflow/api_fastapi/app.py b/airflow-core/src/airflow/api_fastapi/app.py index 8931840c8807c..a8425658407b8 100644 --- a/airflow-core/src/airflow/api_fastapi/app.py +++ b/airflow-core/src/airflow/api_fastapi/app.py @@ -34,6 +34,7 @@ init_middlewares, init_views, ) +from airflow.api_fastapi.dag_processing.app import create_dag_processing_api_app from airflow.api_fastapi.execution_api.app import create_task_execution_api_app from airflow.configuration import conf from airflow.exceptions import AirflowConfigException @@ -61,7 +62,7 @@ def get_cookie_path() -> str: # Fast API apps mounted under these prefixes are not allowed -RESERVED_URL_PREFIXES = ["/api/v2", "/ui", "/execution", "/auth", "/pluginsv2"] +RESERVED_URL_PREFIXES = ["/api/v2", "/ui", "/execution", "/auth", "/pluginsv2", "/dag-processing"] log = logging.getLogger(__name__) @@ -107,6 +108,9 @@ def 
create_app(apps: str = "all") -> FastAPI: init_error_handlers(task_exec_api_app) app.mount("/execution", task_exec_api_app) + if "all" in apps_list or "dag-processing" in apps_list: + app.mount("/dag-processing", create_dag_processing_api_app()) + if "all" in apps_list or "core" in apps_list: app.state.dag_bag = dag_bag init_plugins(app) diff --git a/airflow-core/src/airflow/api_fastapi/dag_processing/__init__.py b/airflow-core/src/airflow/api_fastapi/dag_processing/__init__.py new file mode 100644 index 0000000000000..fb62457db1aaf --- /dev/null +++ b/airflow-core/src/airflow/api_fastapi/dag_processing/__init__.py @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +DAG Processing API (AIP-92). + +A FastAPI sub-app that lets the DAG processor persist parse results without a +direct metadata-DB connection. The processor parses files locally and POSTs the +results here; this app owns the DB writes. 
+""" + +from __future__ import annotations diff --git a/airflow-core/src/airflow/api_fastapi/dag_processing/app.py b/airflow-core/src/airflow/api_fastapi/dag_processing/app.py new file mode 100644 index 0000000000000..e7e04dbd014f6 --- /dev/null +++ b/airflow-core/src/airflow/api_fastapi/dag_processing/app.py @@ -0,0 +1,314 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""DAG Processing API sub-app (AIP-92): persistence endpoints for the DAG processor.""" + +from __future__ import annotations + +import logging +from datetime import timedelta + +from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request +from fastapi.responses import JSONResponse +from sqlalchemy import select, update +from sqlalchemy.orm import load_only + +from airflow._shared.timezones import timezone +from airflow.api_fastapi.dag_processing.datamodels import ( + BundleStateResponse, + BundleStateUpdateBody, + CallbackClaimBody, + JobCompleteBody, + JobRegisterBody, + ParsingResultBody, + PriorityClaimBody, + ReconcileBody, + StaleDagsBody, +) +from airflow.api_fastapi.dag_processing.security import require_dag_processing_auth +from airflow.dag_processing.bundles.manager import DagBundlesManager +from airflow.dag_processing.collection import update_dag_parsing_results_in_db +from airflow.jobs.job import Job +from airflow.models.asset import remove_references_to_deleted_dags +from airflow.models.dag import DagModel +from airflow.models.dagbag import DagPriorityParsingRequest +from airflow.models.dagbundle import DagBundleModel +from airflow.models.dagwarning import DagWarning +from airflow.models.db_callback_request import DbCallbackRequest +from airflow.models.errors import ParseImportError +from airflow.serialization.serialized_objects import LazyDeserializedDAG +from airflow.utils.session import create_session +from airflow.utils.sqlalchemy import prohibit_commit, with_row_locks + +log = logging.getLogger(__name__) + +router = APIRouter() + + +def health() -> dict: + return {"status": "healthy"} + + +@router.post("/parsing-results", status_code=201) +def persist_parsing_results(body: ParsingResultBody) -> dict: + """Persist one file's parse results. 
Mirrors ``persist_parsing_result``.""" + try: + dags = [LazyDeserializedDAG.model_validate(d) for d in body.serialized_dags] + except Exception as e: + raise HTTPException(status_code=422, detail=f"Invalid serialized_dags payload: {e}") + + import_errors: dict[tuple[str, str], str] = {} + if body.import_errors: + import_errors = {(body.bundle_name, rel): err for rel, err in body.import_errors.items()} + + files_parsed: set[tuple[str, str]] | None = None + if body.relative_fileloc is not None: + files_parsed = {(body.bundle_name, body.relative_fileloc)} + files_parsed.update(import_errors.keys()) + + warnings = [DagWarning(**warn) for warn in (body.warnings or [])] + + with create_session() as session: + update_dag_parsing_results_in_db( + bundle_name=body.bundle_name, + bundle_version=body.bundle_version, + version_data=body.version_data, + dags=dags, + import_errors=import_errors, + parse_duration=body.run_duration, + warnings=set(warnings), + session=session, + files_parsed=files_parsed, + ) + + return {"bundle_name": body.bundle_name, "relative_fileloc": body.relative_fileloc} + + +@router.post("/bundles/{bundle_name}/reconcile") +def reconcile_bundle(bundle_name: str, body: ReconcileBody) -> dict: + """ + Deactivate DAGs/import-errors for files no longer present in the bundle. + + Mirrors ``deactivate_deleted_dags`` + ``clear_orphaned_import_errors``. These run in + separate transactions, and import-error cleanup swallows its own errors, so a cleanup + failure cannot roll back the deactivations (matching the original behaviour). 
+ """ + observed = set(body.observed_filelocs) + + with create_session() as session: + deactivated = DagModel.deactivate_deleted_dags( + bundle_name=bundle_name, + rel_filelocs=observed, + session=session, + ) + if deactivated: + remove_references_to_deleted_dags(session=session) + + try: + with create_session() as session: + errors = session.scalars( + select(ParseImportError) + .where(ParseImportError.bundle_name == bundle_name) + .options(load_only(ParseImportError.filename)) + ) + for error in errors: + if error.filename not in observed: + session.delete(error) + except Exception: + log.exception("Error removing old import errors for bundle %s", bundle_name) + + return {"bundle_name": bundle_name, "deactivated": bool(deactivated)} + + +@router.get("/bundles/{bundle_name}/state", response_model=BundleStateResponse) +def get_bundle_state(bundle_name: str) -> BundleStateResponse: + """Return a bundle's persisted refresh state (last_refreshed + version).""" + with create_session() as session: + row = session.scalar( + select(DagBundleModel) + .where(DagBundleModel.name == bundle_name) + .options(load_only(DagBundleModel.last_refreshed, DagBundleModel.version)) + ) + if row is None: + return BundleStateResponse(found=False) + return BundleStateResponse(found=True, last_refreshed=row.last_refreshed, version=row.version) + + +@router.patch("/bundles/{bundle_name}/state") +def update_bundle_state(bundle_name: str, body: BundleStateUpdateBody) -> dict: + """Persist a bundle's post-refresh state. 
Updates ``version`` only when provided.""" + values: dict = {"last_refreshed": body.last_refreshed} + if body.version is not None: + values["version"] = body.version + with create_session() as session: + session.execute(update(DagBundleModel).where(DagBundleModel.name == bundle_name).values(**values)) + return {"bundle_name": bundle_name} + + +@router.post("/bundles/sync") +def sync_bundles() -> dict: + """Sync the configured DAG bundles to the metadata database (server-side).""" + DagBundlesManager().sync_bundles_to_db() + return {"synced": True} + + +@router.post("/stale-dags") +def deactivate_stale_dags(body: StaleDagsBody) -> dict: + """ + Deactivate DAGs whose files have not been re-parsed within the stale threshold. + + Mirrors ``DagFileProcessorManager.deactivate_stale_dags`` server-side. + """ + last_parsed = {(e.bundle_name, e.relative_fileloc): e.last_finish_time for e in body.last_parsed} + to_deactivate: set[str] = set() + deactivated = 0 + with create_session() as session: + inactive_bundles = set( + session.scalars(select(DagBundleModel.name).where(DagBundleModel.active.is_(False))).all() + ) + rows = session.execute( + select( + DagModel.dag_id, + DagModel.bundle_name, + DagModel.last_parsed_time, + DagModel.relative_fileloc, + ).where(~DagModel.is_stale) + ) + for row in rows: + if row.bundle_name in inactive_bundles: + to_deactivate.add(row.dag_id) + continue + last_finish_time = last_parsed.get((row.bundle_name, row.relative_fileloc)) + if last_finish_time and ( + row.last_parsed_time + timedelta(seconds=body.stale_dag_threshold) < last_finish_time + ): + to_deactivate.add(row.dag_id) + if to_deactivate: + result = session.execute( + update(DagModel) + .where(DagModel.dag_id.in_(to_deactivate)) + .values(is_stale=True) + .execution_options(synchronize_session="fetch") + ) + deactivated = getattr(result, "rowcount", 0) + return {"deactivated": deactivated} + + +@router.post("/purge-warnings") +def purge_inactive_dag_warnings() -> dict: + """Delete 
warnings for inactive/stale DAGs (server-side).""" + DagWarning.purge_inactive_dag_warnings() + return {"purged": True} + + +@router.post("/priority-parse-requests/claim") +def claim_priority_parse_requests(body: PriorityClaimBody) -> dict: + """Claim (select + delete, one transaction) priority parse requests for the given bundles.""" + claimed: list[dict] = [] + with create_session() as session: + requests = session.scalars( + select(DagPriorityParsingRequest).where( + DagPriorityParsingRequest.bundle_name.in_(body.bundle_names) + ) + ) + for request in requests: + claimed.append({"bundle_name": request.bundle_name, "relative_fileloc": request.relative_fileloc}) + session.delete(request) + return {"claimed": claimed} + + +@router.post("/callbacks/claim") +def claim_callbacks(body: CallbackClaimBody) -> dict: + """ + Claim callbacks for the given bundles using FOR UPDATE SKIP LOCKED, server-side. + + Mirrors ``DagFileProcessorManager._fetch_callbacks_from_db``. Returns each claimed + callback's raw ``{req_class, req_data}`` payload so the caller can rebuild the typed + ``CallbackRequest`` exactly as ``DbCallbackRequest.get_callback_request`` does. 
+ """ + claimed: list[dict] = [] + with create_session() as session: + with prohibit_commit(session) as guard: + query = with_row_locks( + select(DbCallbackRequest) + .where(DbCallbackRequest.bundle_name.in_(body.bundle_names)) + .order_by(DbCallbackRequest.priority_weight.desc()) + .limit(body.limit), + of=DbCallbackRequest, + session=session, + skip_locked=True, + ) + callbacks = [cb[0] if isinstance(cb, tuple) else cb for cb in session.scalars(query)] + for callback in callbacks: + claimed.append(callback.data) + session.delete(callback) + guard.commit() + return {"callbacks": claimed} + + +@router.post("/jobs", status_code=201) +def register_job(body: JobRegisterBody) -> dict: + """Register the processor's liveness Job row (server-side) and return its id.""" + job = Job() + job.job_type = body.job_type + with create_session() as session: + job.prepare_for_execution(session=session) + return {"job_id": job.id} + + +@router.post("/jobs/{job_id}/heartbeat") +def job_heartbeat(job_id: int) -> dict: + """Update the processor Job's latest_heartbeat so the health check sees it alive.""" + with create_session() as session: + job = session.get(Job, job_id) + if job is None: + raise HTTPException(status_code=404, detail="Job not found") + job.latest_heartbeat = timezone.utcnow() + session.merge(job) + return {"alive": True} + + +@router.post("/jobs/{job_id}/complete") +def complete_job(job_id: int, body: JobCompleteBody) -> dict: + """Record the processor Job's terminal state and end time.""" + with create_session() as session: + job = session.get(Job, job_id) + if job is not None: + job.end_date = timezone.utcnow() + job.state = body.state + session.merge(job) + return {"completed": True} + + +def create_dag_processing_api_app() -> FastAPI: + """Create the DAG Processing API sub-app (mounted at ``/dag-processing``).""" + app = FastAPI( + title="Airflow DAG Processing API", + description="Persistence endpoints for the DAG processor (AIP-92).", + ) + + 
@app.exception_handler(Exception) + async def _unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse: + # Mounted sub-apps build their own middleware stack, so without this the parent + # app never logs unhandled exceptions raised here. + log.exception("Unhandled exception in DAG Processing API: %s %s", request.method, request.url.path) + return JSONResponse(status_code=500, content={"detail": "Internal Server Error"}) + + # /health stays unauthenticated so external readiness probes work; every persistence + # endpoint requires a valid DAG-processor token. + app.add_api_route("/health", health, methods=["GET"]) + app.include_router(router, dependencies=[Depends(require_dag_processing_auth)]) + return app diff --git a/airflow-core/src/airflow/api_fastapi/dag_processing/datamodels.py b/airflow-core/src/airflow/api_fastapi/dag_processing/datamodels.py new file mode 100644 index 0000000000000..130edb0adf739 --- /dev/null +++ b/airflow-core/src/airflow/api_fastapi/dag_processing/datamodels.py @@ -0,0 +1,97 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Request/response models for the DAG Processing API (AIP-92).""" + +from __future__ import annotations + +from datetime import datetime + +from pydantic import BaseModel + + +class ParsingResultBody(BaseModel): + """One file's parse output, mirroring ``DagFileProcessorManager.persist_parsing_result``.""" + + bundle_name: str + bundle_version: str | None = None + version_data: dict | None = None + relative_fileloc: str | None = None + run_duration: float | None = None + serialized_dags: list[dict] = [] + import_errors: dict[str, str] | None = None + warnings: list[dict] | None = None + + +class ReconcileBody(BaseModel): + """The full set of source paths currently observed in a bundle, for stale reconciliation.""" + + observed_filelocs: list[str] + + +class BundleStateResponse(BaseModel): + """A bundle's persisted refresh state (``found=False`` when it has no record).""" + + found: bool + last_refreshed: datetime | None = None + version: str | None = None + + +class BundleStateUpdateBody(BaseModel): + """Post-refresh state for a bundle. 
``version=None`` leaves the stored version unchanged.""" + + last_refreshed: datetime + version: str | None = None + + +class StaleDagEntry(BaseModel): + """A parsed file's identity and last parse-finish time, for the stale-dag sweep.""" + + bundle_name: str + relative_fileloc: str + last_finish_time: datetime + + +class StaleDagsBody(BaseModel): + """Inputs for the time-based stale-dag sweep.""" + + stale_dag_threshold: int + last_parsed: list[StaleDagEntry] + + +class PriorityClaimBody(BaseModel): + """Bundle names to claim priority parse requests for.""" + + bundle_names: list[str] + + +class CallbackClaimBody(BaseModel): + """Bundle names and per-loop limit for claiming callbacks.""" + + bundle_names: list[str] + limit: int + + +class JobRegisterBody(BaseModel): + """Job type to register for the processor's liveness record.""" + + job_type: str + + +class JobCompleteBody(BaseModel): + """Terminal state to record when the processor stops.""" + + state: str diff --git a/airflow-core/src/airflow/api_fastapi/dag_processing/security.py b/airflow-core/src/airflow/api_fastapi/dag_processing/security.py new file mode 100644 index 0000000000000..0449fe9b3cb98 --- /dev/null +++ b/airflow-core/src/airflow/api_fastapi/dag_processing/security.py @@ -0,0 +1,73 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +"""Authentication for the DAG Processing API (AIP-92).""" + +from __future__ import annotations + +import functools +import logging + +from fastapi import Depends, HTTPException, status +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer + +from airflow.api_fastapi.auth.tokens import JWTValidator, get_sig_validation_args +from airflow.configuration import conf + +log = logging.getLogger(__name__) + + +@functools.lru_cache(maxsize=1) +def _token_validator() -> JWTValidator: + """ + Build the JWT validator for the DAG Processing API. + + ``get_sig_validation_args`` selects symmetric (deployment ``jwt_secret``) or asymmetric/JWKS + validation (``[api_auth] trusted_jwks_url``) the same way the Execution API does, so a + multi-tenant deployment validates externally-issued tokens without any change here. + """ + required_claims = frozenset({"aud", "exp", "iat"}) + issuer = conf.get("api_auth", "jwt_issuer", fallback=None) + if issuer: + required_claims = required_claims | {"iss"} + return JWTValidator( + required_claims=required_claims, + issuer=issuer, + audience=conf.get_mandatory_list_value("dag_processor", "jwt_audience"), + **get_sig_validation_args(make_secret_key_if_needed=False), + ) + + +_bearer = HTTPBearer(auto_error=False) + + +def require_dag_processing_auth( + creds: HTTPAuthorizationCredentials | None = Depends(_bearer), +) -> None: + """ + Reject DAG Processing API calls that lack a valid DAG-processor token. + + The DAG processor presents a bearer token provisioned by the deployment for the configured + ``[dag_processor] jwt_audience`` (it does not hold the signing key itself); this validates the + signature, audience, and (when ``[api_auth] jwt_issuer`` is set) issuer so the persistence endpoints are not callable unauthenticated.
+ """ + if creds is None: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing auth token") + try: + _token_validator().validated_claims(creds.credentials) + except Exception: + log.warning("Invalid DAG Processing API token", exc_info=True) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid auth token") diff --git a/airflow-core/tests/unit/api_fastapi/dag_processing/__init__.py b/airflow-core/tests/unit/api_fastapi/dag_processing/__init__.py new file mode 100644 index 0000000000000..21d298ede6ed3 --- /dev/null +++ b/airflow-core/tests/unit/api_fastapi/dag_processing/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations diff --git a/airflow-core/tests/unit/api_fastapi/dag_processing/test_app.py b/airflow-core/tests/unit/api_fastapi/dag_processing/test_app.py new file mode 100644 index 0000000000000..84b133860a13a --- /dev/null +++ b/airflow-core/tests/unit/api_fastapi/dag_processing/test_app.py @@ -0,0 +1,366 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +import pytest +from fastapi.testclient import TestClient + +from airflow.api_fastapi.dag_processing.app import create_dag_processing_api_app +from airflow.api_fastapi.dag_processing.security import require_dag_processing_auth + +from tests_common.test_utils.config import conf_vars + +APP = "airflow.api_fastapi.dag_processing.app" + + +@pytest.fixture +def client(): + # Endpoint-logic tests bypass the token check; auth itself is covered separately below. 
+ app = create_dag_processing_api_app() + app.dependency_overrides[require_dag_processing_auth] = lambda: None + return TestClient(app) + + +def test_health(client): + resp = client.get("/health") + assert resp.status_code == 200 + assert resp.json() == {"status": "healthy"} + + +def test_persistence_endpoint_requires_a_token(): + """A persistence endpoint rejects an unauthenticated call (no token).""" + unauth = TestClient(create_dag_processing_api_app(), raise_server_exceptions=False) + resp = unauth.post("/bundles/sync") + assert resp.status_code == 401 + + +def test_persistence_endpoint_rejects_invalid_token(): + """A persistence endpoint rejects a malformed/garbage token.""" + unauth = TestClient(create_dag_processing_api_app(), raise_server_exceptions=False) + resp = unauth.post("/bundles/sync", headers={"Authorization": "Bearer not-a-real-token"}) + assert resp.status_code == 403 + + +def test_health_does_not_require_a_token(): + """``/health`` stays open so external readiness probes work.""" + unauth = TestClient(create_dag_processing_api_app()) + assert unauth.get("/health").status_code == 200 + + +def test_persistence_endpoint_accepts_a_valid_token(): + """A token signed with the deployment key for the DAG-processing audience is accepted.""" + from airflow.api_fastapi.auth.tokens import JWTGenerator, get_signing_args + from airflow.api_fastapi.dag_processing.security import _token_validator + from airflow.configuration import conf + + with conf_vars({("api_auth", "jwt_secret"): "dag-processing-test-secret"}): + _token_validator.cache_clear() + # A trusted component (not the DAG processor) mints the token the processor presents. + # Mirror the audience/issuer the validator reads so the token is accepted. 
+ token = JWTGenerator( + valid_for=300, + audience=conf.get_mandatory_list_value("dag_processor", "jwt_audience")[0], + issuer=conf.get("api_auth", "jwt_issuer", fallback=None), + **get_signing_args(make_secret_key_if_needed=False), + ).generate({"sub": "dag-processor"}) + authed = TestClient(create_dag_processing_api_app(), raise_server_exceptions=False) + with mock.patch(f"{APP}.DagBundlesManager"): + resp = authed.post("/bundles/sync", headers={"Authorization": f"Bearer {token}"}) + _token_validator.cache_clear() + assert resp.status_code == 200 + + +def test_persist_parsing_results_forwards_to_update_db(client): + with ( + mock.patch(f"{APP}.update_dag_parsing_results_in_db") as update_db, + mock.patch(f"{APP}.LazyDeserializedDAG") as lazy_dag, + mock.patch(f"{APP}.create_session"), + ): + lazy_dag.model_validate.side_effect = lambda d: d + body = { + "bundle_name": "b1", + "bundle_version": "v1", + "relative_fileloc": "dags/a.py", + "run_duration": 1.5, + "serialized_dags": [{"data": {"dag": {"dag_id": "x"}}}], + "import_errors": {"dags/a.py": "boom"}, + "warnings": [], + } + resp = client.post("/parsing-results", json=body) + + assert resp.status_code == 201 + update_db.assert_called_once() + kwargs = update_db.call_args.kwargs + assert kwargs["bundle_name"] == "b1" + assert kwargs["bundle_version"] == "v1" + # import errors are re-keyed to (bundle_name, relative_fileloc) tuples + assert kwargs["import_errors"] == {("b1", "dags/a.py"): "boom"} + # files_parsed includes the parsed file plus any import-error files + assert ("b1", "dags/a.py") in kwargs["files_parsed"] + + +def test_persist_parsing_results_without_fileloc_has_no_files_parsed(client): + with ( + mock.patch(f"{APP}.update_dag_parsing_results_in_db") as update_db, + mock.patch(f"{APP}.LazyDeserializedDAG"), + mock.patch(f"{APP}.create_session"), + ): + resp = client.post("/parsing-results", json={"bundle_name": "b1", "serialized_dags": []}) + + assert resp.status_code == 201 + assert 
update_db.call_args.kwargs["files_parsed"] is None + + +def test_persist_parsing_results_invalid_payload_returns_422(client): + with ( + mock.patch(f"{APP}.LazyDeserializedDAG") as lazy_dag, + mock.patch(f"{APP}.update_dag_parsing_results_in_db") as update_db, + ): + lazy_dag.model_validate.side_effect = ValueError("bad blob") + resp = client.post("/parsing-results", json={"bundle_name": "b1", "serialized_dags": [{"data": {}}]}) + + assert resp.status_code == 422 + update_db.assert_not_called() + + +def test_reconcile_deactivates_and_removes_references(client): + with ( + mock.patch(f"{APP}.DagModel") as dag_model, + mock.patch(f"{APP}.remove_references_to_deleted_dags") as remove_refs, + mock.patch(f"{APP}.create_session") as create_session, + ): + dag_model.deactivate_deleted_dags.return_value = True + create_session.return_value.__enter__.return_value.scalars.return_value = [] + resp = client.post("/bundles/b1/reconcile", json={"observed_filelocs": ["dags/a.py"]}) + + assert resp.status_code == 200 + assert resp.json() == {"bundle_name": "b1", "deactivated": True} + dag_model.deactivate_deleted_dags.assert_called_once() + remove_refs.assert_called_once() + + +def test_reconcile_skips_remove_references_when_nothing_deactivated(client): + with ( + mock.patch(f"{APP}.DagModel") as dag_model, + mock.patch(f"{APP}.remove_references_to_deleted_dags") as remove_refs, + mock.patch(f"{APP}.create_session") as create_session, + ): + dag_model.deactivate_deleted_dags.return_value = False + create_session.return_value.__enter__.return_value.scalars.return_value = [] + resp = client.post("/bundles/b1/reconcile", json={"observed_filelocs": []}) + + assert resp.status_code == 200 + assert resp.json()["deactivated"] is False + remove_refs.assert_not_called() + + +def test_reconcile_import_error_cleanup_failure_is_swallowed(client): + """A cleanup failure must not fail the request nor (being a separate txn) undo deactivation.""" + with ( + mock.patch(f"{APP}.DagModel") as 
dag_model, + mock.patch(f"{APP}.remove_references_to_deleted_dags"), + mock.patch(f"{APP}.create_session") as create_session, + ): + dag_model.deactivate_deleted_dags.return_value = True + # The import-error cleanup block calls session.scalars(); make it blow up. + create_session.return_value.__enter__.return_value.scalars.side_effect = RuntimeError("db down") + resp = client.post("/bundles/b1/reconcile", json={"observed_filelocs": ["dags/a.py"]}) + + assert resp.status_code == 200 + # Deactivation still ran (in its own, earlier transaction). + dag_model.deactivate_deleted_dags.assert_called_once() + + +def test_get_bundle_state_found(client): + from datetime import datetime, timezone + + row = mock.MagicMock(last_refreshed=datetime(2024, 1, 1, tzinfo=timezone.utc), version="v1") + with ( + mock.patch(f"{APP}.create_session") as create_session, + ): + create_session.return_value.__enter__.return_value.scalar.return_value = row + resp = client.get("/bundles/b1/state") + + assert resp.status_code == 200 + body = resp.json() + assert body["found"] is True + assert body["version"] == "v1" + + +def test_get_bundle_state_not_found(client): + with ( + mock.patch(f"{APP}.create_session") as create_session, + ): + create_session.return_value.__enter__.return_value.scalar.return_value = None + resp = client.get("/bundles/missing/state") + + assert resp.status_code == 200 + assert resp.json() == {"found": False, "last_refreshed": None, "version": None} + + +def test_update_bundle_state_includes_version_when_provided(client): + with ( + mock.patch(f"{APP}.update") as update_stmt, + mock.patch(f"{APP}.create_session") as create_session, + ): + resp = client.patch( + "/bundles/b1/state", + json={"last_refreshed": "2024-06-01T08:00:00+00:00", "version": "abc"}, + ) + + assert resp.status_code == 200 + # version supplied -> both columns written + update_stmt.return_value.where.return_value.values.assert_called_once_with( + last_refreshed=mock.ANY, version="abc" + ) + 
create_session.return_value.__enter__.return_value.execute.assert_called_once() + + +def test_update_bundle_state_omits_version_when_null(client): + with ( + mock.patch(f"{APP}.update") as update_stmt, + mock.patch(f"{APP}.create_session"), + ): + resp = client.patch( + "/bundles/b1/state", + json={"last_refreshed": "2024-06-01T08:00:00+00:00", "version": None}, + ) + + assert resp.status_code == 200 + # version omitted -> only last_refreshed written (stored version left unchanged) + update_stmt.return_value.where.return_value.values.assert_called_once_with(last_refreshed=mock.ANY) + + +def test_sync_bundles_triggers_server_side_sync(client): + with mock.patch(f"{APP}.DagBundlesManager") as dbm: + resp = client.post("/bundles/sync") + + assert resp.status_code == 200 + dbm.return_value.sync_bundles_to_db.assert_called_once_with() + + +def test_deactivate_stale_dags_inactive_bundle_and_threshold(client): + from datetime import datetime, timedelta, timezone + + now = datetime(2024, 6, 1, 12, 0, 0, tzinfo=timezone.utc) + r_inactive = mock.MagicMock( + dag_id="d_inactive", bundle_name="dead", last_parsed_time=now, relative_fileloc="a.py" + ) + r_stale = mock.MagicMock( + dag_id="d_stale", + bundle_name="live", + last_parsed_time=now - timedelta(hours=2), + relative_fileloc="b.py", + ) + r_fresh = mock.MagicMock( + dag_id="d_fresh", bundle_name="live", last_parsed_time=now, relative_fileloc="c.py" + ) + with mock.patch(f"{APP}.create_session") as create_session: + session = create_session.return_value.__enter__.return_value + session.scalars.return_value.all.return_value = ["dead"] + session.execute.side_effect = [[r_inactive, r_stale, r_fresh], mock.MagicMock(rowcount=2)] + body = { + "stale_dag_threshold": 60, + "last_parsed": [ + {"bundle_name": "live", "relative_fileloc": "b.py", "last_finish_time": now.isoformat()}, + {"bundle_name": "live", "relative_fileloc": "c.py", "last_finish_time": now.isoformat()}, + ], + } + resp = client.post("/stale-dags", json=body) + + 
assert resp.status_code == 200 + # d_inactive (dead bundle) + d_stale (past threshold); d_fresh stays. + assert resp.json()["deactivated"] == 2 + + +def test_purge_warnings_endpoint(client): + with mock.patch(f"{APP}.DagWarning") as dag_warning: + resp = client.post("/purge-warnings") + + assert resp.status_code == 200 + dag_warning.purge_inactive_dag_warnings.assert_called_once_with() + + +def test_claim_priority_parse_requests(client): + req = mock.MagicMock(bundle_name="b1", relative_fileloc="dags/a.py") + with mock.patch(f"{APP}.create_session") as create_session: + session = create_session.return_value.__enter__.return_value + session.scalars.return_value = [req] + resp = client.post("/priority-parse-requests/claim", json={"bundle_names": ["b1"]}) + + assert resp.status_code == 200 + assert resp.json()["claimed"] == [{"bundle_name": "b1", "relative_fileloc": "dags/a.py"}] + session.delete.assert_called_once_with(req) + + +def test_claim_callbacks_skip_locked_and_delete(client): + cb = mock.MagicMock(data={"req_class": "DagCallbackRequest", "req_data": "{}"}) + with ( + mock.patch(f"{APP}.with_row_locks"), + mock.patch(f"{APP}.prohibit_commit") as prohibit, + mock.patch(f"{APP}.create_session") as create_session, + ): + session = create_session.return_value.__enter__.return_value + session.scalars.return_value = [cb] + resp = client.post("/callbacks/claim", json={"bundle_names": ["b1"], "limit": 5}) + + assert resp.status_code == 200 + assert resp.json()["callbacks"] == [{"req_class": "DagCallbackRequest", "req_data": "{}"}] + session.delete.assert_called_once_with(cb) + # the claim is committed via the prohibit_commit guard + prohibit.return_value.__enter__.return_value.commit.assert_called_once() + + +def test_register_job(client): + job = mock.MagicMock(id=42) + with mock.patch(f"{APP}.Job", return_value=job), mock.patch(f"{APP}.create_session"): + resp = client.post("/jobs", json={"job_type": "DagProcessorJob"}) + + assert resp.status_code == 201 + assert 
resp.json() == {"job_id": 42} + job.prepare_for_execution.assert_called_once() + + +def test_job_heartbeat_updates_latest_heartbeat(client): + job = mock.MagicMock() + with mock.patch(f"{APP}.create_session") as create_session: + create_session.return_value.__enter__.return_value.get.return_value = job + resp = client.post("/jobs/7/heartbeat") + + assert resp.status_code == 200 + assert resp.json() == {"alive": True} + assert job.latest_heartbeat is not None + + +def test_job_heartbeat_missing_returns_404(client): + with mock.patch(f"{APP}.create_session") as create_session: + create_session.return_value.__enter__.return_value.get.return_value = None + resp = client.post("/jobs/7/heartbeat") + + assert resp.status_code == 404 + + +def test_complete_job_sets_state(client): + job = mock.MagicMock() + with mock.patch(f"{APP}.create_session") as create_session: + create_session.return_value.__enter__.return_value.get.return_value = job + resp = client.post("/jobs/7/complete", json={"state": "success"}) + + assert resp.status_code == 200 + assert job.state == "success" diff --git a/airflow-core/tests/unit/api_fastapi/dag_processing/test_app_db.py b/airflow-core/tests/unit/api_fastapi/dag_processing/test_app_db.py new file mode 100644 index 0000000000000..1ef15ca3d0263 --- /dev/null +++ b/airflow-core/tests/unit/api_fastapi/dag_processing/test_app_db.py @@ -0,0 +1,291 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Real-database tests for the DAG Processing API persistence endpoints. + +These complement ``test_app.py`` (which mocks ``create_session``/``with_row_locks``) +by running the endpoints against a real session, so the actual SQL -- ordering, +limit, WHERE clauses, lock, and delete -- is exercised end to end. +""" + +from __future__ import annotations + +import json +from datetime import timedelta + +import pytest +from fastapi.testclient import TestClient +from sqlalchemy import select, update + +from airflow._shared.timezones import timezone +from airflow.api_fastapi.dag_processing.app import create_dag_processing_api_app +from airflow.api_fastapi.dag_processing.security import require_dag_processing_auth +from airflow.callbacks.callback_requests import DagCallbackRequest +from airflow.models.dag import DagModel +from airflow.models.dagbag import DagPriorityParsingRequest +from airflow.models.dagbundle import DagBundleModel +from airflow.models.db_callback_request import DbCallbackRequest +from airflow.models.errors import ParseImportError +from airflow.models.serialized_dag import SerializedDagModel + +from tests_common.test_utils.db import ( + clear_db_callbacks, + clear_db_dag_bundles, + clear_db_dag_parsing_requests, + clear_db_dags, + clear_db_import_errors, + clear_db_serialized_dags, +) + +pytestmark = pytest.mark.db_test + + +@pytest.fixture +def client(): + # The token check is covered in test_app.py; here we bypass it and exercise the real DB. 
+ app = create_dag_processing_api_app() + app.dependency_overrides[require_dag_processing_auth] = lambda: None + return TestClient(app) + + +def _make_callback(*, bundle_name: str, run_id: str) -> DagCallbackRequest: + return DagCallbackRequest( + dag_id="test_dag", + bundle_name=bundle_name, + bundle_version=None, + filepath="some_dag.py", + is_failure_callback=True, + run_id=run_id, + ) + + +class TestClaimCallbacks: + def setup_method(self): + clear_db_callbacks() + + def teardown_method(self): + clear_db_callbacks() + + def test_claim_orders_by_priority_honors_limit_filters_bundle_and_deletes(self, client, session): + """POST /callbacks/claim returns the owned bundle's callbacks ordered by priority_weight + DESC up to ``limit``, deletes the claimed rows, and leaves the other bundle's row and the + rows beyond the limit untouched.""" + # Arrange: three callbacks for the owned bundle (varying priority) + one for another bundle. + session.add( + DbCallbackRequest(callback=_make_callback(bundle_name="owned", run_id="low"), priority_weight=1) + ) + session.add( + DbCallbackRequest(callback=_make_callback(bundle_name="owned", run_id="high"), priority_weight=10) + ) + session.add( + DbCallbackRequest(callback=_make_callback(bundle_name="owned", run_id="mid"), priority_weight=5) + ) + session.add( + DbCallbackRequest( + callback=_make_callback(bundle_name="other", run_id="other"), priority_weight=100 + ) + ) + session.commit() + + # Act: claim from the owned bundle with a limit of 2. + resp = client.post("/callbacks/claim", json={"bundle_names": ["owned"], "limit": 2}) + + # Assert: highest-priority owned callbacks first, count capped at the limit. + assert resp.status_code == 200 + claimed = resp.json()["callbacks"] + assert len(claimed) == 2 + run_ids = [json.loads(cb["req_data"])["run_id"] for cb in claimed] + assert run_ids == ["high", "mid"] + + # The claimed rows are deleted; the lower-priority owned row and the other bundle's row remain. 
+ remaining = session.scalars(select(DbCallbackRequest)).all() + remaining_run_ids = {json.loads(cb.data["req_data"])["run_id"] for cb in remaining} + assert remaining_run_ids == {"low", "other"} + + +class TestDeactivateStaleDags: + def setup_method(self): + clear_db_serialized_dags() + clear_db_dags() + clear_db_dag_bundles() + + def teardown_method(self): + clear_db_serialized_dags() + clear_db_dags() + clear_db_dag_bundles() + + def test_marks_inactive_bundle_and_past_threshold_dags_keeps_fresh_and_serialized( + self, client, dag_maker, session + ): + """POST /stale-dags marks the inactive-bundle DAG and the past-threshold DAG stale, leaves + the freshly parsed DAG active, and preserves the SerializedDagModel row.""" + now = timezone.utcnow() + threshold = 60 + + # A freshly parsed DAG with a real SerializedDagModel, in an active bundle. + with dag_maker(dag_id="fresh_dag", bundle_name="live", serialized=True, session=session): + pass + dag_maker.dag_model.last_parsed_time = now + dag_maker.dag_model.relative_fileloc = "fresh.py" + dag_maker.dag_model.is_stale = False + session.merge(dag_maker.dag_model) + + # An inactive bundle plus a DAG that lives in it. + session.add(DagBundleModel(name="gone-bundle")) + session.flush() + session.execute( + update(DagBundleModel).where(DagBundleModel.name == "gone-bundle").values(active=False) + ) + session.add( + DagModel( + dag_id="dag_in_inactive_bundle", + bundle_name="gone-bundle", + relative_fileloc="inactive.py", + last_parsed_time=now, + is_stale=False, + ) + ) + # A DAG in the active bundle whose last parse predates last_finish_time by more than the threshold. + session.add( + DagModel( + dag_id="dag_past_threshold", + bundle_name="live", + relative_fileloc="stale.py", + last_parsed_time=now - timedelta(hours=2), + is_stale=False, + ) + ) + session.commit() + + # Act: last_finish_time for the live-bundle files is "now"; fresh.py's gap is under the + # threshold, stale.py's gap (2h) is over it. 
+ body = { + "stale_dag_threshold": threshold, + "last_parsed": [ + {"bundle_name": "live", "relative_fileloc": "fresh.py", "last_finish_time": now.isoformat()}, + {"bundle_name": "live", "relative_fileloc": "stale.py", "last_finish_time": now.isoformat()}, + ], + } + resp = client.post("/stale-dags", json=body) + + # Assert: two DAGs deactivated; the fresh DAG stays active. + assert resp.status_code == 200 + assert resp.json()["deactivated"] == 2 + + is_stale_by_dag = dict( + session.execute( + select(DagModel.dag_id, DagModel.is_stale).where( + DagModel.dag_id.in_(["fresh_dag", "dag_in_inactive_bundle", "dag_past_threshold"]) + ) + ).all() + ) + assert is_stale_by_dag == { + "fresh_dag": False, + "dag_in_inactive_bundle": True, + "dag_past_threshold": True, + } + + # Deactivating DagModel rows must not delete the SerializedDagModel history. + serialized_count = session.scalar( + select(SerializedDagModel.dag_id).where(SerializedDagModel.dag_id == "fresh_dag") + ) + assert serialized_count == "fresh_dag" + + +class TestReconcileImportErrors: + def setup_method(self): + clear_db_import_errors() + clear_db_serialized_dags() + clear_db_dags() + + def teardown_method(self): + clear_db_import_errors() + clear_db_serialized_dags() + clear_db_dags() + + def test_deletes_absent_file_error_but_keeps_observed_and_zip_inner_errors(self, client, session): + """POST /bundles/{name}/reconcile deletes the import error for a file no longer observed, + while retaining the error for a still-observed file and for a zip inner path that is + observed (e.g. ``test_zip.zip/broken_dag.py``).""" + bundle = "testing" + # One error for a normal file that is still present, one zip-inner error that is still + # observed, and one for a file that is now absent from the bundle. 
+ session.add( + ParseImportError( + filename="present_dag.py", + bundle_name=bundle, + timestamp=timezone.utcnow(), + stacktrace="present error", + ) + ) + session.add( + ParseImportError( + filename="test_zip.zip/broken_dag.py", + bundle_name=bundle, + timestamp=timezone.utcnow(), + stacktrace="zip import error", + ) + ) + session.add( + ParseImportError( + filename="absent_dag.py", + bundle_name=bundle, + timestamp=timezone.utcnow(), + stacktrace="absent error", + ) + ) + session.commit() + + # Act: observed set includes the normal file and the zip inner path, but not absent_dag.py. + resp = client.post( + f"/bundles/{bundle}/reconcile", + json={"observed_filelocs": ["present_dag.py", "test_zip.zip/broken_dag.py"]}, + ) + + # Assert: only the absent file's error is removed. + assert resp.status_code == 200 + remaining = { + err.filename + for err in session.scalars(select(ParseImportError).where(ParseImportError.bundle_name == bundle)) + } + assert remaining == {"present_dag.py", "test_zip.zip/broken_dag.py"} + + +class TestClaimPriorityParseRequests: + def setup_method(self): + clear_db_dag_parsing_requests() + + def teardown_method(self): + clear_db_dag_parsing_requests() + + def test_claims_owned_bundle_request_and_leaves_other_bundle_request(self, client, session): + """POST /priority-parse-requests/claim returns and deletes the owned bundle's request while + leaving a request that belongs to a bundle this processor does not own.""" + session.add(DagPriorityParsingRequest(bundle_name="owned", relative_fileloc="owned_file.py")) + session.add(DagPriorityParsingRequest(bundle_name="other", relative_fileloc="other_file.py")) + session.commit() + + # Act: claim only the owned bundle. + resp = client.post("/priority-parse-requests/claim", json={"bundle_names": ["owned"]}) + + # Assert: the owned request is returned and deleted; the other bundle's request remains. 
+ assert resp.status_code == 200 + assert resp.json()["claimed"] == [{"bundle_name": "owned", "relative_fileloc": "owned_file.py"}] + + remaining = session.scalars(select(DagPriorityParsingRequest)).all() + assert len(remaining) == 1 + assert remaining[0].bundle_name == "other" + assert remaining[0].relative_fileloc == "other_file.py" From 19abc7cac9ef5c5f42a7d4b2b4dd1928b6d482a3 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Wed, 3 Jun 2026 00:38:44 +0100 Subject: [PATCH 3/7] Document the DAG processor's API-server requirement The standalone DAG processor now requires an API server that mounts the dag-processing app, and a deployment-provisioned bearer token. List dag-processing in the api-server --apps options, note the requirement in the web-stack docs, and add a significant newsfragment for the breaking change. --- .../web-stack.rst | 24 +++++++-- .../newsfragments/67878.significant.rst | 53 +++++++++++++++++++ airflow-core/src/airflow/cli/cli_config.py | 2 +- 3 files changed, 74 insertions(+), 5 deletions(-) create mode 100644 airflow-core/newsfragments/67878.significant.rst diff --git a/airflow-core/docs/administration-and-deployment/web-stack.rst b/airflow-core/docs/administration-and-deployment/web-stack.rst index fd32e75db5576..3c661d20b07cb 100644 --- a/airflow-core/docs/administration-and-deployment/web-stack.rst +++ b/airflow-core/docs/administration-and-deployment/web-stack.rst @@ -38,18 +38,24 @@ with the new prefix. Separating API Servers ----------------------- -By default, both the Core API Server and the Execution API Server are served together: +By default, all applications are served together: .. code-block:: bash airflow api-server # same as airflow api-server --apps all - # or + +``--apps all`` serves the Core API Server, the Execution API Server, and the DAG Processing +API Server. You can also select a subset of applications, for example to serve only the Core +and Execution API Servers: + +.. 
code-block:: bash + airflow api-server --apps core,execution -If you want to separate the Core API Server and the Execution API Server, you can run them -separately. This might be useful for scaling them independently or for deploying them on different machines. +If you want to separate the applications, you can run them separately. This might be useful for +scaling them independently or for deploying them on different machines. .. code-block:: bash @@ -57,6 +63,16 @@ separately. This might be useful for scaling them independently or for deploying airflow api-server --apps core # serve only the Execution API Server airflow api-server --apps execution + # serve only the DAG Processing API Server + airflow api-server --apps dag-processing + +.. note:: + + The standalone DAG processor (``airflow dag-processor``) reads and writes its metadata through + the DAG Processing API Server, so the API server it connects to must include the ``dag-processing`` + app. ``airflow api-server --apps all`` includes it; a subset such as ``--apps core,execution`` does + not. If the ``dag-processing`` app is missing, the DAG processor's requests return ``404`` and it + cannot run. Known Issues ------------ diff --git a/airflow-core/newsfragments/67878.significant.rst b/airflow-core/newsfragments/67878.significant.rst new file mode 100644 index 0000000000000..c2bf91ab2beca --- /dev/null +++ b/airflow-core/newsfragments/67878.significant.rst @@ -0,0 +1,53 @@ +DAG processor now reads and writes metadata exclusively through the API server + +The standalone DAG processor (``airflow dag-processor``) no longer connects to the metadata +database directly. It persists parse results and reads all metadata through the API server, +mirroring how workers and triggerers already operate. + +**What changed:** + +- The DAG processor process opens no metadata-database connection. 
Persistence (serialized + DAGs, import errors, warnings), stale-DAG and orphaned-import-error reconciliation, bundle + synchronization and state, priority-parse-request and callback claiming, and the processor's + own ``Job`` liveness record all go through the API server. +- Parse-time ``Connection``/``Variable``/``XCom`` reads from DAG files resolve through the + Execution API rather than an in-process app backed by the local database. + +**Behaviour changes:** + +- An API server must be reachable for the DAG processor to run. A deployment that previously + ran ``airflow dag-processor`` with only a database connection (no API server) must now also + run the API server. +- The API server fronting the DAG processor must run the ``dag-processing`` app. Start it with + ``airflow api-server --apps all`` or include ``dag-processing`` explicitly; a subset such as + ``--apps core,execution`` omits it and the DAG processor cannot run. +- The DAG processor authenticates to the API server with a bearer token the deployment + provisions; it does not mint its own. Because it parses user code, it is not given the signing + key. Provide the token via a file referenced by ``[dag_processor] api_token_path``. + +**Configuration:** + +- ``[core] dag_processing_api_server_url`` selects the DAG Processing API endpoint. It defaults + to the ``/dag-processing`` mount on the configured API server (``{BASE_URL}/dag-processing``), + so deployments that already run the API server need no extra configuration. Set it to point + the DAG processor at a different host. +- ``[dag_processor] api_token_path`` is the path to a file holding the bearer token the DAG + processor presents to the API server (for both the DAG Processing and Execution APIs). The + token is minted by the deployment or control plane, not the processor. + +**Migration:** + +- Ensure the API server is running and reachable from the DAG processor. 
If the DAG processor + runs on a host without the default API server, set ``[core] dag_processing_api_server_url`` + (and ``[core] execution_api_server_url`` for parse-time metadata reads) to its address. + +* Types of change + + * [ ] Dag changes + * [x] Config changes + * [x] CLI changes + * [x] Behaviour changes + * [ ] Plugin changes + * [ ] Dependency changes + * [x] Code interface changes + * [ ] API changes diff --git a/airflow-core/src/airflow/cli/cli_config.py b/airflow-core/src/airflow/cli/cli_config.py index c43d79d5bda49..5dd7692748308 100644 --- a/airflow-core/src/airflow/cli/cli_config.py +++ b/airflow-core/src/airflow/cli/cli_config.py @@ -752,7 +752,7 @@ def string_lower_type(val): ) ARG_API_SERVER_APPS = Arg( ("--apps",), - help="Applications to run (comma-separated). Default is all. Options: core, execution, all", + help="Applications to run (comma-separated). Default is all. Options: core, execution, dag-processing, all", default="all", ) ARG_API_SERVER_ALLOW_PROXY_FORWARDING = Arg( From 6a25dcaae33500482f0ef4cc5dacd1cb3ee553dc Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Wed, 3 Jun 2026 03:42:08 +0100 Subject: [PATCH 4/7] Provision the DAG processor's API token instead of self-signing it The DAG processor parses user code, so it must not hold the JWT signing key or mint its own token. It now carries a bearer token a trusted component provisions to the file at [dag_processor] api_token_path and only reads that file, re-reading it as the token is rotated so a refreshed token is picked up without a restart. - DagProcessingApiClient reads its token via a callable bearer-auth on each request (short cache plus a 401-triggered re-read), and the Execution API client used for parse-time Connection/Variable reads is re-read per parser spawn. - airflow.api_fastapi.auth.dag_processor_token mints the dual-audience token. 
It signs with whatever get_signing_args resolves (symmetric secret or an asymmetric private key), so a JWKS-based control plane validates externally-issued tokens unchanged. - 'airflow provision-dag-processor-token' is the trusted minter CLI; airflow standalone now provisions through it. [dag_processor] jwt_expiration_time sets the token lifetime. --- .../dagfile-processing.rst | 18 +++ .../newsfragments/67878.significant.rst | 16 ++- .../api_fastapi/auth/dag_processor_token.py | 103 ++++++++++++++++++ airflow-core/src/airflow/cli/cli_config.py | 18 +++ .../cli/commands/dag_processor_command.py | 16 +++ .../cli/commands/standalone_command.py | 23 +--- .../src/airflow/config_templates/config.yml | 15 ++- .../src/airflow/dag_processing/api_client.py | 74 ++++++++++++- .../src/airflow/dag_processing/manager.py | 8 +- .../auth/test_dag_processor_token.py | 85 +++++++++++++++ .../cli/commands/test_standalone_command.py | 8 +- .../unit/dag_processing/test_api_client.py | 93 ++++++++++++++++ 12 files changed, 445 insertions(+), 32 deletions(-) create mode 100644 airflow-core/src/airflow/api_fastapi/auth/dag_processor_token.py create mode 100644 airflow-core/tests/unit/api_fastapi/auth/test_dag_processor_token.py create mode 100644 airflow-core/tests/unit/dag_processing/test_api_client.py diff --git a/airflow-core/docs/administration-and-deployment/dagfile-processing.rst b/airflow-core/docs/administration-and-deployment/dagfile-processing.rst index b83c72171a225..99a98a84c5b09 100644 --- a/airflow-core/docs/administration-and-deployment/dagfile-processing.rst +++ b/airflow-core/docs/administration-and-deployment/dagfile-processing.rst @@ -45,6 +45,24 @@ The ``DagFileProcessorManager`` runs user codes. As a result, it runs as a stand 4. Return DagBag: Provide the ``DagFileProcessorManager`` a list of the discovered Dag objects +Communicating with the API server +---------------------------------- + +The Dag processor does not connect to the metadata database directly. 
It persists parse results and +reads metadata (including parse-time ``Connection``, ``Variable`` and ``XCom`` values) through the +API server, so an API server running the ``dag-processing`` app must be reachable. Start it with +``airflow api-server --apps all`` or include ``dag-processing`` explicitly; a subset such as +``--apps core,execution`` omits it and the Dag processor cannot run. + +Because the Dag processor parses user code, it must not hold the signing key or mint its own token. +A trusted component mints a bearer token and writes it to the file named by +:ref:`config:dag_processor__api_token_path`; the Dag processor only reads that file, re-reading it +as the token is rotated, so a refreshed token is picked up without a restart. Mint the token with +``airflow provision-dag-processor-token`` from a trusted context that holds the signing key, and +re-run it before :ref:`config:dag_processor__jwt_expiration_time` elapses. The official Helm chart +(via an init container) and the docker-compose example do this for you. + + Fine-tuning your Dag processor performance ------------------------------------------ diff --git a/airflow-core/newsfragments/67878.significant.rst b/airflow-core/newsfragments/67878.significant.rst index c2bf91ab2beca..7c221d643f334 100644 --- a/airflow-core/newsfragments/67878.significant.rst +++ b/airflow-core/newsfragments/67878.significant.rst @@ -23,7 +23,13 @@ mirroring how workers and triggerers already operate. ``--apps core,execution`` omits it and the DAG processor cannot run. - The DAG processor authenticates to the API server with a bearer token the deployment provisions; it does not mint its own. Because it parses user code, it is not given the signing - key. Provide the token via a file referenced by ``[dag_processor] api_token_path``. + key. 
A trusted component mints the token and writes it to the file referenced by + ``[dag_processor] api_token_path``; the processor only reads that file, re-reading it as the + token is rotated, so a refreshed token is picked up without a restart. +- ``airflow provision-dag-processor-token`` mints the token and writes it to that file. The + bundled Helm chart (an init container) and docker-compose example run it for you; in a custom + deployment, run it from a trusted context that holds the signing key before the processor starts, + and re-run it before ``[dag_processor] jwt_expiration_time`` elapses to rotate the token. **Configuration:** @@ -32,14 +38,18 @@ mirroring how workers and triggerers already operate. so deployments that already run the API server need no extra configuration. Set it to point the DAG processor at a different host. - ``[dag_processor] api_token_path`` is the path to a file holding the bearer token the DAG - processor presents to the API server (for both the DAG Processing and Execution APIs). The - token is minted by the deployment or control plane, not the processor. + processor presents to the API server (for both the DAG Processing and Execution APIs). +- ``[dag_processor] jwt_expiration_time`` is the lifetime of the minted token (default 24h); the + provisioning step should re-mint before it elapses. **Migration:** - Ensure the API server is running and reachable from the DAG processor. If the DAG processor runs on a host without the default API server, set ``[core] dag_processing_api_server_url`` (and ``[core] execution_api_server_url`` for parse-time metadata reads) to its address. +- Provision the DAG processor's token (``airflow provision-dag-processor-token``) from a trusted + context that holds the signing key, and point the processor at it with + ``[dag_processor] api_token_path``. The Helm chart and docker-compose example do this by default. 
* Types of change diff --git a/airflow-core/src/airflow/api_fastapi/auth/dag_processor_token.py b/airflow-core/src/airflow/api_fastapi/auth/dag_processor_token.py new file mode 100644 index 0000000000000..f632e0b651562 --- /dev/null +++ b/airflow-core/src/airflow/api_fastapi/auth/dag_processor_token.py @@ -0,0 +1,103 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Mint and provision the bearer token the DAG processor presents to the API server (AIP-92). + +The DAG processor parses (and forks) user code, so it must never hold the deployment signing key +or mint its own token. A *trusted* component runs the helpers here -- the deployment's provisioning +step (a Helm init container, a docker-compose init service) or ``airflow standalone`` -- mints the +token and writes it to ``[dag_processor] api_token_path``. The processor only ever reads that file +(re-reading it as it is rotated), so it carries a token without being able to forge one. 
+""" + +from __future__ import annotations + +import logging +import os +from pathlib import Path + +from airflow.api_fastapi.auth.tokens import JWTGenerator, get_signing_args +from airflow.configuration import conf + +log = logging.getLogger(__name__) + +# The Execution API is task-instance scoped: its ``sub`` is validated as a UUID. The DAG processor +# is not a task instance, so its token carries an all-zero sentinel UUID rather than a real id. +DAG_PROCESSOR_TOKEN_SUBJECT = "00000000-0000-0000-0000-000000000000" + + +def mint_dag_processor_token( + *, + valid_for: int | None = None, + make_secret_key_if_needed: bool = False, +) -> str: + """ + Mint the bearer token the DAG processor presents to the API server. + + Trusted callers only. The token carries *both* the Execution and DAG Processing audiences (the + processor calls both APIs) under :data:`DAG_PROCESSOR_TOKEN_SUBJECT`. Signing uses whatever + ``get_signing_args`` resolves -- the symmetric ``[api_auth] jwt_secret`` or, where configured, an + asymmetric private key -- so the same minting works for a single deployment or a JWKS-based + control plane without change here. + + :param valid_for: token lifetime in seconds; defaults to ``[dag_processor] jwt_expiration_time``. + :param make_secret_key_if_needed: generate a signing key if none is configured. Leave ``False`` + in real deployments (the key must be the shared one the API server validates with); only + ``airflow standalone``, which materialises that shared key itself, passes ``True``. + """ + audiences = [ + conf.get_mandatory_list_value("execution_api", "jwt_audience")[0], + conf.get_mandatory_list_value("dag_processor", "jwt_audience")[0], + ] + generator = JWTGenerator( + valid_for=valid_for if valid_for is not None else conf.getint("dag_processor", "jwt_expiration_time"), + # A JWT ``aud`` may be a list; the generator's hint is single-audience, but the processor + # presents this one token to both the Execution and DAG Processing APIs. 
+ audience=audiences, # type: ignore[arg-type] + issuer=conf.get("api_auth", "jwt_issuer", fallback=None), + **get_signing_args(make_secret_key_if_needed=make_secret_key_if_needed), + ) + return generator.generate({"sub": DAG_PROCESSOR_TOKEN_SUBJECT}) + + +def provision_dag_processor_token_file( + output: str | os.PathLike[str] | None = None, + *, + valid_for: int | None = None, + make_secret_key_if_needed: bool = False, +) -> str: + """ + Mint a DAG processor token and write it to ``output`` (or ``[dag_processor] api_token_path``). + + The token is written atomically (temp file then ``rename``) so the processor, which re-reads the + file, never observes a half-written token. Re-run before the token expires to rotate it in place. + + :return: the path written. + """ + target = output or conf.get("dag_processor", "api_token_path", fallback=None) + if not target: + raise ValueError( + "No output path for the DAG processor token: pass output or set [dag_processor] api_token_path." + ) + token = mint_dag_processor_token(valid_for=valid_for, make_secret_key_if_needed=make_secret_key_if_needed) + path = Path(target) + path.parent.mkdir(parents=True, exist_ok=True) + tmp_path = path.with_name(f"{path.name}.tmp") + tmp_path.write_text(token) + os.replace(tmp_path, path) + log.info("Wrote DAG processor API token to %s", path) + return str(path) diff --git a/airflow-core/src/airflow/cli/cli_config.py b/airflow-core/src/airflow/cli/cli_config.py index 5dd7692748308..a09da021f459f 100644 --- a/airflow-core/src/airflow/cli/cli_config.py +++ b/airflow-core/src/airflow/cli/cli_config.py @@ -218,6 +218,12 @@ def string_lower_type(val): type=str, default="[CWD]" if BUILD_DOCS else os.getcwd(), ) +ARG_DAG_PROCESSOR_TOKEN_OUTPUT = Arg( + ("--output",), + help="File to write the token to. 
Defaults to the `[dag_processor] api_token_path` config.", + type=str, + default=None, +) ARG_PID = Arg(("--pid",), help="PID file location", nargs="?") ARG_DAEMON = Arg( ("-D", "--daemon"), help="Daemonize instead of running in the foreground", action="store_true" @@ -2249,6 +2255,18 @@ class GroupCommand(NamedTuple): ARG_DEV, ), ), + ActionCommand( + name="provision-dag-processor-token", + help=( + "Mint the DAG processor's API token and write it to a file. Run by a trusted " + "deployment step (which holds the signing key), not by the processor itself." + ), + func=lazy_load_command("airflow.cli.commands.dag_processor_command.provision_dag_processor_token"), + args=( + ARG_DAG_PROCESSOR_TOKEN_OUTPUT, + ARG_VERBOSE, + ), + ), ActionCommand( name="version", help="Show the version", diff --git a/airflow-core/src/airflow/cli/commands/dag_processor_command.py b/airflow-core/src/airflow/cli/commands/dag_processor_command.py index 03e27436fe8df..c8c0a67005027 100644 --- a/airflow-core/src/airflow/cli/commands/dag_processor_command.py +++ b/airflow-core/src/airflow/cli/commands/dag_processor_command.py @@ -22,6 +22,7 @@ import time from typing import Any +from airflow.api_fastapi.auth.dag_processor_token import provision_dag_processor_token_file from airflow.cli.commands.daemon_utils import run_command_with_daemon_option from airflow.dag_processing.manager import DagFileProcessorManager from airflow.jobs.dag_processor_job_runner import DagProcessorJobRunner @@ -92,6 +93,21 @@ def _heartbeat() -> None: log.warning("Failed to mark DAG processor Job %s as %s", job_id, state, exc_info=True) +@cli_utils.action_cli +@providers_configuration_loaded +def provision_dag_processor_token(args): + """ + Mint the DAG processor's API token and write it to a file (trusted bootstrap step). + + Run by a trusted component (a deployment init step, not the processor itself), which holds the + signing key. Writes to ``--output`` or, by default, ``[dag_processor] api_token_path``. 
The DAG + processor then only reads that file. Re-run before ``[dag_processor] jwt_expiration_time`` + elapses to rotate the token in place. + """ + path = provision_dag_processor_token_file(getattr(args, "output", None)) + print(f"DAG processor API token written to {path}") + + @enable_memray_trace(component=MemrayTraceComponents.dag_processor) @cli_utils.action_cli @providers_configuration_loaded diff --git a/airflow-core/src/airflow/cli/commands/standalone_command.py b/airflow-core/src/airflow/cli/commands/standalone_command.py index 81f1215dd1b28..b6403c84b85d4 100644 --- a/airflow-core/src/airflow/cli/commands/standalone_command.py +++ b/airflow-core/src/airflow/cli/commands/standalone_command.py @@ -29,6 +29,8 @@ from termcolor import colored from airflow.api_fastapi.app import create_auth_manager +from airflow.api_fastapi.auth.dag_processor_token import provision_dag_processor_token_file +from airflow.api_fastapi.auth.tokens import get_signing_key from airflow.configuration import conf from airflow.executors import executor_constants from airflow.executors.executor_loader import ExecutorLoader @@ -205,8 +207,6 @@ def _provision_dag_processor_token(self, env: dict[str, str]) -> None: Processing audiences (the processor calls both APIs) and a sentinel subject, since the Execution API is task-instance scoped. """ - from airflow.api_fastapi.auth.tokens import JWTGenerator, get_signing_args, get_signing_key - try: # Materialise a signing key shared by every standalone subprocess, so the api-server # validates the token with the same key it is minted with here. 
@@ -214,22 +214,11 @@ def _provision_dag_processor_token(self, env: dict[str, str]) -> None: env["AIRFLOW__API_AUTH__JWT_SECRET"] = secret os.environ["AIRFLOW__API_AUTH__JWT_SECRET"] = secret - audiences = [ - conf.get_mandatory_list_value("execution_api", "jwt_audience")[0], - conf.get_mandatory_list_value("dag_processor", "jwt_audience")[0], - ] - token = JWTGenerator( - valid_for=conf.getint("execution_api", "jwt_expiration_time"), - # A JWT ``aud`` may be a list; the generator's hint is single-audience, but the - # processor presents this one token to both the Execution and DAG Processing APIs. - audience=audiences, # type: ignore[arg-type] - issuer=conf.get("api_auth", "jwt_issuer", fallback=None), - **get_signing_args(make_secret_key_if_needed=True), - ).generate({"sub": "00000000-0000-0000-0000-000000000000"}) - + # Standalone is the trusted launcher, so it mints the processor's token here (the + # processor never mints its own); the same minting is used by real deployments. fd, path = tempfile.mkstemp(prefix="airflow-standalone-", suffix=".token") - with os.fdopen(fd, "w") as token_file: - token_file.write(token) + os.close(fd) + provision_dag_processor_token_file(path, make_secret_key_if_needed=True) env["AIRFLOW__DAG_PROCESSOR__API_TOKEN_PATH"] = path except Exception as e: self.print_output("standalone", f"Could not provision the DAG processor API token: {e}") diff --git a/airflow-core/src/airflow/config_templates/config.yml b/airflow-core/src/airflow/config_templates/config.yml index 5d7548a880bd2..26a676a1d1262 100644 --- a/airflow-core/src/airflow/config_templates/config.yml +++ b/airflow-core/src/airflow/config_templates/config.yml @@ -3001,11 +3001,24 @@ dag_processor: Path to a file containing the bearer token the DAG processor presents to the API server (for both the DAG Processing and Execution APIs). 
The DAG processor parses user code, so it must not hold the signing key or mint its own token; the deployment (or control plane) - provisions this token and writes it to the file. If unset, the processor sends no token. + provisions this token and writes it to the file. The processor re-reads this file as it is + rotated, so a refreshed token is picked up without a restart. If unset, the processor sends + no token. example: "/var/run/airflow/dag-processor-token" default: ~ type: string + jwt_expiration_time: + version_added: 3.3.0 + description: | + Lifetime, in seconds, of the token a trusted component (the deployment's token-provisioning + step, or ``airflow standalone``) mints for the DAG processor and writes to ``api_token_path``. + The provisioning step should re-mint before this elapses so the long-running processor keeps + a valid token. The processor itself never mints a token; this only configures the minter. + default: "86400" + type: integer + example: ~ + dag_bundle_config_list: description: | List of backend configs. Must supply name, classpath, and kwargs for each backend. diff --git a/airflow-core/src/airflow/dag_processing/api_client.py b/airflow-core/src/airflow/dag_processing/api_client.py index f61c6c45a88cb..dc140b27bdfeb 100644 --- a/airflow-core/src/airflow/dag_processing/api_client.py +++ b/airflow-core/src/airflow/dag_processing/api_client.py @@ -28,7 +28,7 @@ import httpx if TYPE_CHECKING: - from collections.abc import Iterable + from collections.abc import Callable, Iterable from airflow.dag_processing.processor import DagFileParsingResult @@ -43,6 +43,62 @@ _TRANSIENT_BACKOFF = 0.5 +def _parse_iso_datetime(value: str) -> datetime: + """ + Parse an ISO-8601 timestamp from the API server, tolerating a trailing ``Z``. + + The server emits UTC timestamps with a ``Z`` designator (e.g. ``2026-06-03T03:52:13.938753Z``), + but ``datetime.fromisoformat`` only accepts ``Z`` on Python 3.11+. 
Normalise it to an explicit + offset so parsing works on the oldest supported runtime (3.10). + """ + if value.endswith("Z"): + value = f"{value[:-1]}+00:00" + return datetime.fromisoformat(value) + + +# How long a token read from the provisioned file is reused before re-reading. Keeps the tight +# parse loop from stat-ing the file on every request while still picking up a rotated token quickly. +_TOKEN_CACHE_TTL = 30.0 + + +class _CallableBearerAuth(httpx.Auth): + """ + Attach a bearer token read from a callable, re-reading it as the deployment rotates it. + + The DAG processor never holds the signing key; a trusted component provisions its token to a + file (``[dag_processor] api_token_path``) and rotates it there before expiry. Reading the token + per-request (rather than baking it into the client once) lets a long-running processor pick up a + rotated token without restarting. The value is cached for ``_TOKEN_CACHE_TTL`` to avoid a file + read on every call; a ``401`` forces an immediate re-read and one retry, so a rotation that lands + mid-window is honoured without waiting out the cache. + """ + + def __init__(self, token_getter: Callable[[], str | None], *, cache_ttl: float = _TOKEN_CACHE_TTL): + self._token_getter = token_getter + self._cache_ttl = cache_ttl + self._cached: str | None = None + self._cached_at = 0.0 + + def _token(self, *, refresh: bool = False) -> str | None: + now = time.monotonic() + if refresh or self._cached is None or now - self._cached_at >= self._cache_ttl: + self._cached = self._token_getter() + self._cached_at = now + return self._cached + + def auth_flow(self, request: httpx.Request): + token = self._token() + if token: + request.headers["Authorization"] = f"Bearer {token}" + response = yield request + if response.status_code == 401: + # The token may have rotated on disk since the cached read; re-read and retry once. 
+ fresh = self._token(refresh=True) + if fresh and fresh != token: + request.headers["Authorization"] = f"Bearer {fresh}" + yield request + + class DagProcessingApiClient: """ Forward DAG-processor persistence to the ``/dag-processing`` API sub-app. @@ -53,12 +109,20 @@ class DagProcessingApiClient: connections are pooled across the manager's parse loop. """ - def __init__(self, base_url: str, *, token: str | None = None, timeout: float = 30.0) -> None: + def __init__( + self, + base_url: str, + *, + token_getter: Callable[[], str | None] | None = None, + timeout: float = 30.0, + ) -> None: self._base_url = base_url.rstrip("/") - headers = {"Authorization": f"Bearer {token}"} if token else {} + # The token is read from ``token_getter`` per request (see _CallableBearerAuth), not baked + # in, so a token rotated on disk by the deployment is picked up without a restart. + auth = _CallableBearerAuth(token_getter) if token_getter is not None else None # ``retries`` retries connection failures (request never sent) at the transport level. 
self._client = httpx.Client( - headers=headers, + auth=auth, timeout=timeout, transport=httpx.HTTPTransport(retries=_CONNECT_RETRIES), ) @@ -157,7 +221,7 @@ def get_bundle_state(self, bundle_name: str) -> dict | None: return None last_refreshed = data.get("last_refreshed") return { - "last_refreshed": datetime.fromisoformat(last_refreshed) if last_refreshed else None, + "last_refreshed": _parse_iso_datetime(last_refreshed) if last_refreshed else None, "version": data.get("version"), } diff --git a/airflow-core/src/airflow/dag_processing/manager.py b/airflow-core/src/airflow/dag_processing/manager.py index 83192f7592136..31ad8bc1a60af 100644 --- a/airflow-core/src/airflow/dag_processing/manager.py +++ b/airflow-core/src/airflow/dag_processing/manager.py @@ -336,7 +336,7 @@ class DagFileProcessorManager(LoggingMixin): _dag_processing_client: DagProcessingApiClient = attrs.field( init=False, - factory=lambda: DagProcessingApiClient(_dag_processing_api_server_url(), token=_api_token()), + factory=lambda: DagProcessingApiClient(_dag_processing_api_server_url(), token_getter=_api_token), ) """Client for the DAG Processing API. The DAG processor never reads or writes the metadata database directly; all persistence and metadata reads are routed through the API server.""" @@ -1187,11 +1187,13 @@ def _get_logger_for_dag_file(self, dag_file: DagFileInfo): underlying_logger, processors=processors, logger_name="processor" ).bind(), logger_filehandle - @functools.cached_property + @property def client(self) -> Client: # Parse-time connection/variable/xcom reads go to the remote Execution API, so the # processor holds no metadata-DB connection. It carries the externally-provisioned token - # (see _api_token); it does not mint one, since it parses user code. + # (see _api_token); it does not mint one, since it parses user code. 
Not cached: the token + # is re-read here so each parser process spawned later carries a token rotated on disk by + # the deployment, rather than a stale one baked in at manager start-up. return Client( base_url=get_execution_api_server_url(), token=_api_token() or "", diff --git a/airflow-core/tests/unit/api_fastapi/auth/test_dag_processor_token.py b/airflow-core/tests/unit/api_fastapi/auth/test_dag_processor_token.py new file mode 100644 index 0000000000000..67400109d51b7 --- /dev/null +++ b/airflow-core/tests/unit/api_fastapi/auth/test_dag_processor_token.py @@ -0,0 +1,85 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from unittest import mock + +import pytest + +from airflow.api_fastapi.auth import dag_processor_token + +from tests_common.test_utils.config import conf_vars + + +class TestMintDagProcessorToken: + @mock.patch.object(dag_processor_token, "get_signing_args", return_value={"secret_key": "s"}) + @mock.patch.object(dag_processor_token, "JWTGenerator") + def test_carries_both_audiences_and_sentinel_subject(self, mock_generator, _signing_args): + mock_generator.return_value.generate.return_value = "minted-token" + with conf_vars( + { + ("execution_api", "jwt_audience"): "exec-aud", + ("dag_processor", "jwt_audience"): "dag-proc-aud", + ("dag_processor", "jwt_expiration_time"): "123", + } + ): + token = dag_processor_token.mint_dag_processor_token() + + assert token == "minted-token" + # Both audiences, since the processor presents this one token to both APIs. + assert mock_generator.call_args.kwargs["audience"] == ["exec-aud", "dag-proc-aud"] + assert mock_generator.call_args.kwargs["valid_for"] == 123 + mock_generator.return_value.generate.assert_called_once_with( + {"sub": dag_processor_token.DAG_PROCESSOR_TOKEN_SUBJECT} + ) + + @mock.patch.object(dag_processor_token, "get_signing_args", return_value={"secret_key": "s"}) + @mock.patch.object(dag_processor_token, "JWTGenerator") + def test_explicit_valid_for_overrides_config(self, mock_generator, _signing_args): + with conf_vars( + { + ("execution_api", "jwt_audience"): "exec-aud", + ("dag_processor", "jwt_audience"): "dag-proc-aud", + } + ): + dag_processor_token.mint_dag_processor_token(valid_for=42) + assert mock_generator.call_args.kwargs["valid_for"] == 42 + + +class TestProvisionDagProcessorTokenFile: + @mock.patch.object(dag_processor_token, "mint_dag_processor_token", return_value="minted") + def test_writes_to_explicit_path_and_creates_parents(self, _mint, tmp_path): + target = tmp_path / "nested" / "dag-processor-token" + + returned = 
dag_processor_token.provision_dag_processor_token_file(target) + + assert returned == str(target) + assert target.read_text() == "minted" + # Atomic write leaves no temp file behind. + assert not (target.parent / "dag-processor-token.tmp").exists() + + @mock.patch.object(dag_processor_token, "mint_dag_processor_token", return_value="minted") + def test_defaults_to_configured_api_token_path(self, _mint, tmp_path): + target = tmp_path / "token" + with conf_vars({("dag_processor", "api_token_path"): str(target)}): + dag_processor_token.provision_dag_processor_token_file() + assert target.read_text() == "minted" + + def test_raises_when_no_path_configured(self): + with conf_vars({("dag_processor", "api_token_path"): ""}): + with pytest.raises(ValueError, match="api_token_path"): + dag_processor_token.provision_dag_processor_token_file(None) diff --git a/airflow-core/tests/unit/cli/commands/test_standalone_command.py b/airflow-core/tests/unit/cli/commands/test_standalone_command.py index 8bb2b56b4e1ce..4149c5928a94e 100644 --- a/airflow-core/tests/unit/cli/commands/test_standalone_command.py +++ b/airflow-core/tests/unit/cli/commands/test_standalone_command.py @@ -85,9 +85,11 @@ class FakeExecutor: assert "AIRFLOW__CORE__AUTH_MANAGER" not in env - @mock.patch("airflow.api_fastapi.auth.tokens.JWTGenerator") - @mock.patch("airflow.api_fastapi.auth.tokens.get_signing_args", return_value={"secret_key": "s"}) - @mock.patch("airflow.api_fastapi.auth.tokens.get_signing_key", return_value="the-secret") + @mock.patch("airflow.api_fastapi.auth.dag_processor_token.JWTGenerator") + @mock.patch( + "airflow.api_fastapi.auth.dag_processor_token.get_signing_args", return_value={"secret_key": "s"} + ) + @mock.patch("airflow.cli.commands.standalone_command.get_signing_key", return_value="the-secret") def test_provision_dag_processor_token(self, _get_key, _get_args, mock_generator): """Standalone mints the processor's token and provisions it via env + a token file.""" 
mock_generator.return_value.generate.return_value = "minted-token" diff --git a/airflow-core/tests/unit/dag_processing/test_api_client.py b/airflow-core/tests/unit/dag_processing/test_api_client.py new file mode 100644 index 0000000000000..c22f6c1d5b70e --- /dev/null +++ b/airflow-core/tests/unit/dag_processing/test_api_client.py @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from datetime import datetime, timezone + +import httpx +import pytest + +from airflow.dag_processing.api_client import _CallableBearerAuth, _parse_iso_datetime + + +class TestParseIsoDatetime: + @pytest.mark.parametrize( + "value", + [ + "2026-06-03T03:52:13.938753Z", # trailing Z (server format; rejected by 3.10 fromisoformat) + "2026-06-03T03:52:13.938753+00:00", # explicit offset + ], + ) + def test_parses_utc_with_or_without_z(self, value): + parsed = _parse_iso_datetime(value) + assert parsed == datetime(2026, 6, 3, 3, 52, 13, 938753, tzinfo=timezone.utc) + assert parsed.utcoffset() == timezone.utc.utcoffset(None) + + +def _drive_auth_flow(auth: _CallableBearerAuth, statuses: list[int]) -> list[str | None]: + """Drive ``auth.auth_flow`` the way httpx does, feeding it the given response statuses. 
+ + Returns the ``Authorization`` header value sent for each request (snapshotted at send time, + since httpx mutates the one request object in place on a 401-triggered retry). A retry shows up + as a second entry. + """ + gen = auth.auth_flow(httpx.Request("GET", "http://host/path")) + sent_auth: list[str | None] = [] + try: + request = next(gen) + for status in statuses: + sent_auth.append(request.headers.get("Authorization")) + request = gen.send(httpx.Response(status, request=request)) + except StopIteration: + pass + return sent_auth + + +class TestCallableBearerAuth: + def test_attaches_bearer_token_from_getter(self): + auth = _CallableBearerAuth(lambda: "abc") + assert _drive_auth_flow(auth, [200]) == ["Bearer abc"] + + def test_no_header_when_token_is_none(self): + auth = _CallableBearerAuth(lambda: None) + assert _drive_auth_flow(auth, [200]) == [None] + + def test_token_is_cached_within_ttl(self): + calls = [] + + def getter(): + calls.append(1) + return "t" + + auth = _CallableBearerAuth(getter, cache_ttl=1000.0) + _drive_auth_flow(auth, [200]) + _drive_auth_flow(auth, [200]) + # Second request reused the cached token rather than re-reading. + assert len(calls) == 1 + + def test_401_rereads_token_and_retries_once(self): + tokens = iter(["stale", "fresh"]) + auth = _CallableBearerAuth(lambda: next(tokens), cache_ttl=1000.0) + + # First request carries the stale token; the 401 re-reads (bypassing the cache) and retries + # with the fresh one. + assert _drive_auth_flow(auth, [401, 200]) == ["Bearer stale", "Bearer fresh"] + + def test_no_retry_when_reread_token_unchanged(self): + auth = _CallableBearerAuth(lambda: "same", cache_ttl=1000.0) + # Re-read returned the same token, so there's nothing to retry with. 
+ assert _drive_auth_flow(auth, [401, 200]) == ["Bearer same"] From b99b667bd2254fe47e4f7f62d94bb39641a44eb3 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Wed, 3 Jun 2026 03:42:47 +0100 Subject: [PATCH 5/7] Provision the DAG processor token in the Helm chart and docker-compose The DAG processor needs a bearer token but must not hold the signing key. These deployments now mint the token in a trusted step and share it with the processor: - Helm: an init container (which holds the signing key) mints the token into a shared emptyDir; the dag-processor container reads it read-only and is not given the key. Toggle with dagProcessor.apiToken.provisionViaInitContainer. - docker-compose (quick-start docs and task-sdk integration tests): airflow-init mints the token to a shared volume that the dag-processor reads. --- .../howto/docker-compose/docker-compose.yaml | 12 ++++++ .../dag-processor-deployment.yaml | 40 +++++++++++++++++++ .../airflow_core/test_dag_processor.py | 4 +- chart/values.schema.json | 17 ++++++++ chart/values.yaml | 13 ++++++ .../docker-compose.yaml | 10 +++++ 6 files changed, 95 insertions(+), 1 deletion(-) diff --git a/airflow-core/docs/howto/docker-compose/docker-compose.yaml b/airflow-core/docs/howto/docker-compose/docker-compose.yaml index 6a8891f3f58cc..0beb5554a7b41 100644 --- a/airflow-core/docs/howto/docker-compose/docker-compose.yaml +++ b/airflow-core/docs/howto/docker-compose/docker-compose.yaml @@ -66,6 +66,11 @@ x-airflow-common: AIRFLOW__CORE__EXECUTION_API_SERVER_URL: 'http://airflow-apiserver:8080/execution/' AIRFLOW__API_AUTH__JWT_SECRET: ${AIRFLOW__API_AUTH__JWT_SECRET:-airflow_jwt_secret} AIRFLOW__API_AUTH__JWT_ISSUER: ${AIRFLOW__API_AUTH__JWT_ISSUER:-airflow} + # The DAG processor parses user code, so it never mints its own API token: airflow-init mints + # one to this shared file and the processor reads it (re-reading it as it is rotated). 
For a + # hardened deployment, run the processor in a separate image that does not carry the signing + # key above; see the Helm chart for that pattern. + AIRFLOW__DAG_PROCESSOR__API_TOKEN_PATH: '/opt/airflow/api-token/dag-processor-token' # yamllint disable rule:line-length # Use simple http server on scheduler for health checks # See https://airflow.apache.org/docs/apache-airflow/stable/administration-and-deployment/logging-monitoring/check-health.html#scheduler-health-check-server @@ -81,6 +86,7 @@ x-airflow-common: - ${AIRFLOW_PROJ_DIR:-.}/logs:/opt/airflow/logs - ${AIRFLOW_PROJ_DIR:-.}/config:/opt/airflow/config - ${AIRFLOW_PROJ_DIR:-.}/plugins:/opt/airflow/plugins + - airflow-api-token:/opt/airflow/api-token user: "${AIRFLOW_UID:-50000}:0" depends_on: &airflow-common-depends-on @@ -270,6 +276,10 @@ services: echo /entrypoint airflow config list >/dev/null echo + echo "Minting the DAG processor API token (the processor never mints its own)." + echo + /entrypoint airflow provision-dag-processor-token + echo echo "Files in shared volumes:" echo ls -la /opt/airflow/{logs,dags,plugins,config} @@ -335,3 +345,5 @@ services: volumes: postgres-db-volume: + # Shares the DAG processor API token from airflow-init (which mints it) to the processor. 
+ airflow-api-token: diff --git a/chart/templates/dag-processor/dag-processor-deployment.yaml b/chart/templates/dag-processor/dag-processor-deployment.yaml index 1f42c471afd0d..faf9a761bbb4e 100644 --- a/chart/templates/dag-processor/dag-processor-deployment.yaml +++ b/chart/templates/dag-processor/dag-processor-deployment.yaml @@ -31,6 +31,8 @@ {{- $containerSecurityContextLogGroomerSidecar := include "containerSecurityContext" (list .Values.dagProcessor.logGroomerSidecar .Values) }} {{- $containerSecurityContextWaitForMigrations := include "containerSecurityContext" (list .Values.dagProcessor.waitForMigrations .Values) }} {{- $containerLifecycleHooks := or .Values.dagProcessor.containerLifecycleHooks .Values.containerLifecycleHooks }} +{{- $provisionApiToken := .Values.dagProcessor.apiToken.provisionViaInitContainer }} +{{- $apiTokenDir := dir .Values.dagProcessor.apiToken.path }} apiVersion: apps/v1 kind: Deployment metadata: @@ -137,6 +139,35 @@ spec: {{- tpl (toYaml .Values.dagProcessor.waitForMigrations.env) $ | nindent 12 }} {{- end }} {{- end }} + {{- if $provisionApiToken }} + - name: provision-api-token + # The DAG processor parses user code, so it must not hold the signing key. This trusted + # init container holds it (IncludeJwtSecret true), mints the processor's token into a + # shared volume, and exits; the dag-processor container below reads that file without the + # signing key. Re-created on every pod start, so the token is fresh per restart. + resources: {{- toYaml .Values.dagProcessor.resources | nindent 12 }} + image: {{ template "airflow_image" . 
}} + imagePullPolicy: {{ .Values.images.airflow.pullPolicy }} + securityContext: {{ $containerSecurityContext | nindent 12 }} + args: + - "bash" + - "-c" + - "exec airflow provision-dag-processor-token" + volumeMounts: + {{- if .Values.volumeMounts }} + {{- toYaml .Values.volumeMounts | nindent 12 }} + {{- end }} + {{- if .Values.dagProcessor.extraVolumeMounts }} + {{- tpl (toYaml .Values.dagProcessor.extraVolumeMounts) . | nindent 12 }} + {{- end }} + - name: dag-processor-api-token + mountPath: {{ $apiTokenDir | quote }} + {{- include "airflow_config_mount" . | nindent 12 }} + envFrom: {{- include "custom_airflow_environment_from" . | default "\n []" | indent 10 }} + env: + {{- include "custom_airflow_environment" . | indent 10 }} + {{- include "standard_airflow_environment" (merge (dict "IncludeJwtSecret" true) .) | indent 10 }} + {{- end }} {{- if and .Values.dags.gitSync.enabled (not .Values.dags.persistence.enabled) }} {{- include "git_sync_container" (dict "Values" .Values "is_init" "true" "Template" .Template) | nindent 8 }} {{- end }} @@ -170,6 +201,11 @@ spec: {{- if .Values.logs.persistence.subPath }} subPath: {{ .Values.logs.persistence.subPath }} {{- end }} + {{- if $provisionApiToken }} + - name: dag-processor-api-token + mountPath: {{ $apiTokenDir | quote }} + readOnly: true + {{- end }} {{- include "airflow_config_mount" . | nindent 12 }} {{- if or .Values.dags.persistence.enabled .Values.dags.gitSync.enabled }} {{- include "airflow_dags_mount" . 
| nindent 12 }} @@ -277,4 +313,8 @@ spec: - name: logs emptyDir: {{- toYaml (default (dict) .Values.logs.emptyDirConfig) | nindent 12 }} {{- end }} + {{- if $provisionApiToken }} + - name: dag-processor-api-token + emptyDir: {} + {{- end }} {{- end }} diff --git a/chart/tests/helm_tests/airflow_core/test_dag_processor.py b/chart/tests/helm_tests/airflow_core/test_dag_processor.py index 9bc1f6ebb2e84..b92294f4a1aa8 100644 --- a/chart/tests/helm_tests/airflow_core/test_dag_processor.py +++ b/chart/tests/helm_tests/airflow_core/test_dag_processor.py @@ -63,7 +63,9 @@ def test_disable_wait_for_migration(self): actual = jmespath.search( "spec.template.spec.initContainers[?name=='wait-for-airflow-migrations']", docs[0] ) - assert actual is None + # No wait-for-migrations init container (other init containers, e.g. the API token minter, + # may still be present, so this is an empty match rather than a missing initContainers list). + assert not actual def test_wait_for_migration_security_contexts_are_configurable(self): docs = render_chart( diff --git a/chart/values.schema.json b/chart/values.schema.json index bc094f6b5fb67..152ef13743bad 100644 --- a/chart/values.schema.json +++ b/chart/values.schema.json @@ -5742,6 +5742,23 @@ "type": "boolean", "default": true }, + "apiToken": { + "description": "Provisioning of the bearer token the DAG processor presents to the API server. The processor never mints its own token (it parses user code); a trusted init container mints it instead.", + "type": "object", + "additionalProperties": false, + "properties": { + "provisionViaInitContainer": { + "description": "Mint the token in an init container (which holds the signing key) into a shared volume the processor reads. 
Set false to provision the token yourself via ``extraInitContainers``/``extraVolumes`` and ``config.dag_processor.api_token_path``.", + "type": "boolean", + "default": true + }, + "path": { + "description": "Path of the token file, shared between the init container and the processor.", + "type": "string", + "default": "/opt/airflow/dag-processor-token/token" + } + } + }, "dagBundleConfigList": { "description": "Define Dag bundles in a structured YAML format. This will be automatically converted to JSON string format for config.dag_processor.dag_bundle_config_list.", "type": "array", diff --git a/chart/values.yaml b/chart/values.yaml index 4213d755884cb..1243f58fea43e 100644 --- a/chart/values.yaml +++ b/chart/values.yaml @@ -2687,6 +2687,16 @@ triggerer: dagProcessor: enabled: true + # The DAG processor authenticates to the API server with a bearer token but never mints its own + # (it parses user code, so it must not hold the signing key). By default a trusted init container + # mints one -- using the deployment signing key -- into a shared volume that the processor reads + # and re-reads as it is rotated. Set provisionViaInitContainer to false to provision the token + # yourself via extraInitContainers/extraVolumes plus `config.dag_processor.api_token_path`. + apiToken: + provisionViaInitContainer: true + # Where the minted token is written and read, shared between the init container and the processor. + path: /opt/airflow/dag-processor-token/token + # Dag Bundle Configuration # Define Dag bundles in a structured YAML format. This will be automatically # converted to JSON string format for `config.dag_processor.dag_bundle_config_list`. @@ -3947,6 +3957,9 @@ config: # This value is generated by default from `.Values.dagProcessor.dagBundleConfigList` using the `dag_bundle_config_list` helper function. # It is recommended to configure this via `dagProcessor.dagBundleConfigList` rather than overriding `config.dag_processor.dag_bundle_config_list` directly. 
dag_bundle_config_list: '{{ include "dag_bundle_config_list" . }}' + # Points the DAG processor at the token the init container mints (see dagProcessor.apiToken). + # Set in config (not as an env var) so it is read by both the init container and the processor. + api_token_path: '{{ if .Values.dagProcessor.apiToken.provisionViaInitContainer }}{{ .Values.dagProcessor.apiToken.path }}{{ end }}' elasticsearch: json_format: 'True' log_id_template: "{dag_id}-{task_id}-{run_id}-{map_index}-{try_number}" diff --git a/task-sdk-integration-tests/docker-compose.yaml b/task-sdk-integration-tests/docker-compose.yaml index 82143e98b5e4a..402ef265bee8f 100644 --- a/task-sdk-integration-tests/docker-compose.yaml +++ b/task-sdk-integration-tests/docker-compose.yaml @@ -34,6 +34,9 @@ x-airflow-common: AIRFLOW__CORE__EXECUTION_API_SERVER_URL: 'http://airflow-apiserver:8080/execution/' AIRFLOW__API__BASE_URL: 'http://airflow-apiserver:8080/' AIRFLOW__API_AUTH__JWT_SECRET: 'test-secret-key-for-testing' + # The DAG processor never mints its own token (it parses user code): airflow-init mints one to + # this shared file and the processor reads it. See airflow-init / [dag_processor] api_token_path. + AIRFLOW__DAG_PROCESSOR__API_TOKEN_PATH: '/opt/airflow/api-token/dag-processor-token' AIRFLOW_VAR_TEST_VARIABLE_KEY: 'test_variable_value' AIRFLOW_CONN_TEST_CONNECTION: 'postgresql://testuser:testpass@testhost:5432/testdb' HOST_OS: ${HOST_OS:-linux} @@ -41,6 +44,7 @@ x-airflow-common: volumes: - ./dags:/opt/airflow/dags - ./logs:/opt/airflow/logs + - airflow-api-token:/opt/airflow/api-token depends_on: &airflow-common-depends-on postgres: @@ -72,6 +76,8 @@ services: /entrypoint airflow version echo "Running airflow config list to create default config file if missing." /entrypoint airflow config list >/dev/null + echo "Minting the DAG processor API token (processor parses user code, so it never mints its own)." 
+ /entrypoint airflow provision-dag-processor-token if [ "${HOST_OS}" == "linux" ]; then echo "Change ownership of files in /opt/airflow to ${AIRFLOW_UID}:0" chown -R "${AIRFLOW_UID}:0" /opt/airflow/ @@ -128,3 +134,7 @@ services: <<: *airflow-common-depends-on airflow-init: condition: service_completed_successfully + +volumes: + # Shares the DAG processor API token from airflow-init (which mints it) to the processor. + airflow-api-token: From 327717c8d27f4958dd2639ade28b2d0ab343fa99 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Wed, 3 Jun 2026 03:43:07 +0100 Subject: [PATCH 6/7] Stabilize test_stats_total_parse_time against the reparse interval The test ran the manager three times expecting one parse per run, relying on dag_path.touch() to beat the default 30s min_file_process_interval via an mtime comparison. Under load that filesystem-granularity race left a run waiting out the interval and tripping the per-test timeout. Pin min_file_process_interval=0 so each run re-parses unconditionally; the touch is no longer needed. --- airflow-core/tests/unit/dag_processing/test_manager.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/airflow-core/tests/unit/dag_processing/test_manager.py b/airflow-core/tests/unit/dag_processing/test_manager.py index 24fe6c5b7392a..9128b9472f14a 100644 --- a/airflow-core/tests/unit/dag_processing/test_manager.py +++ b/airflow-core/tests/unit/dag_processing/test_manager.py @@ -1899,6 +1899,11 @@ def test_run_parsing_loop_uses_overridable_purge(self, tmp_path, configure_testi manager.run() purge_mock.assert_called() + # min_file_process_interval=0 makes each file unconditionally eligible for re-parsing, so the + # three successive runs below don't hinge on the file's mtime exceeding its last-parsed time + # (a filesystem-granularity race that, under load, left a run waiting out the default 30s + # interval and blew the per-test timeout). 
+ @conf_vars({("dag_processor", "min_file_process_interval"): "0"}) @mock.patch("airflow.dag_processing.manager.stats.gauge") def test_stats_total_parse_time(self, statsd_gauge_mock, tmp_path, configure_testing_dag_bundle): key = "dag_processing.total_parse_time" @@ -1925,7 +1930,6 @@ def test_stats_total_parse_time(self, statsd_gauge_mock, tmp_path, configure_tes assert len(gauge_values[key]) == 1 assert gauge_values[key][0] >= 1e-4 - dag_path.touch() # make the loop run faster gauge_values.clear() def _make_refresh_bundle(self, *, supports_versioning=False, current_version=None): From 5e76e706fcd89556c5744a5909d8bf37989a4ae2 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Thu, 4 Jun 2026 21:50:38 +0100 Subject: [PATCH 7/7] Register the DAG processor's liveness Job with the processor's own hostname The /dag-processing /jobs endpoint built the Job server-side, so it recorded the API server's hostname/pid rather than the processor's. The dag-processor health check (airflow jobs check --job-type DagProcessorJob --hostname HOSTNAME) filters by the processor's hostname, so it never matched the row and 'docker compose up --wait' timed out waiting for the processor to become healthy (the e2e and remote-logging PROD-image tests). The processor now reports its hostname/unixname/pid when registering and the endpoint records them, restoring the in-process behaviour.
--- .../airflow/api_fastapi/dag_processing/app.py | 8 ++++++++ .../api_fastapi/dag_processing/datamodels.py | 10 +++++++++- .../src/airflow/dag_processing/api_client.py | 17 ++++++++++++++++- .../api_fastapi/dag_processing/test_app.py | 19 ++++++++++++++++++- 4 files changed, 51 insertions(+), 3 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/dag_processing/app.py b/airflow-core/src/airflow/api_fastapi/dag_processing/app.py index e7e04dbd014f6..689e374948a7f 100644 --- a/airflow-core/src/airflow/api_fastapi/dag_processing/app.py +++ b/airflow-core/src/airflow/api_fastapi/dag_processing/app.py @@ -264,6 +264,14 @@ def register_job(body: JobRegisterBody) -> dict: """Register the processor's liveness Job row (server-side) and return its id.""" job = Job() job.job_type = body.job_type + # ``Job()`` defaults hostname/unixname/pid to *this* API server's, but the processor runs in a + # different process (and usually host). Record the identity it reported so its health check + # (``airflow jobs check --job-type DagProcessorJob --hostname HOSTNAME``) matches this row. + job.hostname = body.hostname + if body.unixname: + job.unixname = body.unixname + if body.pid is not None: + job.pid = body.pid with create_session() as session: job.prepare_for_execution(session=session) return {"job_id": job.id} diff --git a/airflow-core/src/airflow/api_fastapi/dag_processing/datamodels.py b/airflow-core/src/airflow/api_fastapi/dag_processing/datamodels.py index 130edb0adf739..494d9e96d6ab7 100644 --- a/airflow-core/src/airflow/api_fastapi/dag_processing/datamodels.py +++ b/airflow-core/src/airflow/api_fastapi/dag_processing/datamodels.py @@ -86,9 +86,17 @@ class CallbackClaimBody(BaseModel): class JobRegisterBody(BaseModel): - """Job type to register for the processor's liveness record.""" + """ + Identity of the processor whose liveness Job is being registered.
+ + ``hostname``/``unixname``/``pid`` are the *processor's* (not this API server's), so the Job row + reflects where the processor runs and ``airflow jobs check --hostname`` matches it. + """ job_type: str + hostname: str + unixname: str | None = None + pid: int | None = None class JobCompleteBody(BaseModel): diff --git a/airflow-core/src/airflow/dag_processing/api_client.py b/airflow-core/src/airflow/dag_processing/api_client.py index dc140b27bdfeb..295b035409b4f 100644 --- a/airflow-core/src/airflow/dag_processing/api_client.py +++ b/airflow-core/src/airflow/dag_processing/api_client.py @@ -19,14 +19,18 @@ from __future__ import annotations import logging +import os import time from datetime import datetime +from getpass import getuser from importlib import import_module from typing import TYPE_CHECKING, Any from urllib.parse import quote import httpx +from airflow.utils.net import get_hostname + if TYPE_CHECKING: from collections.abc import Callable, Iterable @@ -268,7 +272,18 @@ def fetch_callbacks(self, *, bundle_names: list[str], limit: int) -> list: def register_job(self, job_type: str) -> int: """Register the processor's liveness Job row server-side and return its id.""" - resp = self._send("POST", "/jobs", json={"job_type": job_type}) + # Report this processor's own identity so the Job row reflects where the processor runs, + # not the API server; the health check matches on hostname (see the /jobs endpoint). 
+ resp = self._send( + "POST", + "/jobs", + json={ + "job_type": job_type, + "hostname": get_hostname(), + "unixname": getuser(), + "pid": os.getpid(), + }, + ) return resp.json()["job_id"] def job_heartbeat(self, job_id: int) -> None: diff --git a/airflow-core/tests/unit/api_fastapi/dag_processing/test_app.py b/airflow-core/tests/unit/api_fastapi/dag_processing/test_app.py index 84b133860a13a..f1a44427c8ab0 100644 --- a/airflow-core/tests/unit/api_fastapi/dag_processing/test_app.py +++ b/airflow-core/tests/unit/api_fastapi/dag_processing/test_app.py @@ -330,11 +330,28 @@ def test_claim_callbacks_skip_locked_and_delete(client): def test_register_job(client): job = mock.MagicMock(id=42) with mock.patch(f"{APP}.Job", return_value=job), mock.patch(f"{APP}.create_session"): - resp = client.post("/jobs", json={"job_type": "DagProcessorJob"}) + resp = client.post( + "/jobs", + json={ + "job_type": "DagProcessorJob", + "hostname": "dag-proc-host", + "unixname": "airflow", + "pid": 1234, + }, + ) assert resp.status_code == 201 assert resp.json() == {"job_id": 42} job.prepare_for_execution.assert_called_once() + # The processor's reported identity is recorded (not this server's), so its health check matches. + assert job.hostname == "dag-proc-host" + assert job.pid == 1234 + + +def test_register_job_requires_hostname(client): + """hostname is required so the Job row reflects the processor, not the API server.""" + resp = client.post("/jobs", json={"job_type": "DagProcessorJob"}) + assert resp.status_code == 422 def test_job_heartbeat_updates_latest_heartbeat(client):