diff --git a/airflow-core/docs/administration-and-deployment/dagfile-processing.rst b/airflow-core/docs/administration-and-deployment/dagfile-processing.rst index b83c72171a225..99a98a84c5b09 100644 --- a/airflow-core/docs/administration-and-deployment/dagfile-processing.rst +++ b/airflow-core/docs/administration-and-deployment/dagfile-processing.rst @@ -45,6 +45,24 @@ The ``DagFileProcessorManager`` runs user codes. As a result, it runs as a stand 4. Return DagBag: Provide the ``DagFileProcessorManager`` a list of the discovered Dag objects +Communicating with the API server +---------------------------------- + +The Dag processor does not connect to the metadata database directly. It persists parse results and +reads metadata (including parse-time ``Connection``, ``Variable`` and ``XCom`` values) through the +API server, so an API server running the ``dag-processing`` app must be reachable. Start it with +``airflow api-server --apps all`` or include ``dag-processing`` explicitly; a subset such as +``--apps core,execution`` omits it and the Dag processor cannot run. + +Because the Dag processor parses user code, it must not hold the signing key or mint its own token. +A trusted component mints a bearer token and writes it to the file named by +:ref:`config:dag_processor__api_token_path`; the Dag processor only reads that file, re-reading it +as the token is rotated, so a refreshed token is picked up without a restart. Mint the token with +``airflow provision-dag-processor-token`` from a trusted context that holds the signing key, and +re-run it before :ref:`config:dag_processor__jwt_expiration_time` elapses. The official Helm chart +(via an init container) and the docker-compose example do this for you. + + Fine-tuning your Dag processor performance ------------------------------------------ diff --git a/airflow-core/docs/administration-and-deployment/web-stack.rst b/airflow-core/docs/administration-and-deployment/web-stack.rst index fd32e75db5576..3c661d20b07cb 100644 --- a/airflow-core/docs/administration-and-deployment/web-stack.rst +++ b/airflow-core/docs/administration-and-deployment/web-stack.rst @@ -38,18 +38,24 @@ with the new prefix. Separating API Servers ----------------------- -By default, both the Core API Server and the Execution API Server are served together: +By default, all applications are served together: .. code-block:: bash airflow api-server # same as airflow api-server --apps all - # or + +``--apps all`` serves the Core API Server, the Execution API Server, and the DAG Processing +API Server. You can also select a subset of applications, for example to serve only the Core +and Execution API Servers: + +.. code-block:: bash + airflow api-server --apps core,execution -If you want to separate the Core API Server and the Execution API Server, you can run them -separately. This might be useful for scaling them independently or for deploying them on different machines. +If you want to separate the applications, you can run them separately. This might be useful for +scaling them independently or for deploying them on different machines. .. code-block:: bash @@ -57,6 +63,16 @@ separately. This might be useful for scaling them independently or for deploying airflow api-server --apps core # serve only the Execution API Server airflow api-server --apps execution + # serve only the DAG Processing API Server + airflow api-server --apps dag-processing + +.. note:: + + The standalone DAG processor (``airflow dag-processor``) reads and writes its metadata through + the DAG Processing API Server, so the API server it connects to must include the ``dag-processing`` + app. ``airflow api-server --apps all`` includes it; a subset such as ``--apps core,execution`` does + not. If the ``dag-processing`` app is missing, the DAG processor's requests return ``404`` and it + cannot run. Known Issues ------------ diff --git a/airflow-core/docs/howto/docker-compose/docker-compose.yaml b/airflow-core/docs/howto/docker-compose/docker-compose.yaml index 6a8891f3f58cc..0beb5554a7b41 100644 --- a/airflow-core/docs/howto/docker-compose/docker-compose.yaml +++ b/airflow-core/docs/howto/docker-compose/docker-compose.yaml @@ -66,6 +66,11 @@ x-airflow-common: AIRFLOW__CORE__EXECUTION_API_SERVER_URL: 'http://airflow-apiserver:8080/execution/' AIRFLOW__API_AUTH__JWT_SECRET: ${AIRFLOW__API_AUTH__JWT_SECRET:-airflow_jwt_secret} AIRFLOW__API_AUTH__JWT_ISSUER: ${AIRFLOW__API_AUTH__JWT_ISSUER:-airflow} + # The DAG processor parses user code, so it never mints its own API token: airflow-init mints + # one to this shared file and the processor reads it (re-reading it as it is rotated). For a + # hardened deployment, run the processor in a separate image that does not carry the signing + # key above; see the Helm chart for that pattern. + AIRFLOW__DAG_PROCESSOR__API_TOKEN_PATH: '/opt/airflow/api-token/dag-processor-token' # yamllint disable rule:line-length # Use simple http server on scheduler for health checks # See https://airflow.apache.org/docs/apache-airflow/stable/administration-and-deployment/logging-monitoring/check-health.html#scheduler-health-check-server @@ -81,6 +86,7 @@ x-airflow-common: - ${AIRFLOW_PROJ_DIR:-.}/logs:/opt/airflow/logs - ${AIRFLOW_PROJ_DIR:-.}/config:/opt/airflow/config - ${AIRFLOW_PROJ_DIR:-.}/plugins:/opt/airflow/plugins + - airflow-api-token:/opt/airflow/api-token user: "${AIRFLOW_UID:-50000}:0" depends_on: &airflow-common-depends-on @@ -270,6 +276,10 @@ services: echo /entrypoint airflow config list >/dev/null echo + echo "Minting the DAG processor API token (the processor never mints its own)." + echo + /entrypoint airflow provision-dag-processor-token + echo echo "Files in shared volumes:" echo ls -la /opt/airflow/{logs,dags,plugins,config} @@ -335,3 +345,5 @@ services: volumes: postgres-db-volume: + # Shares the DAG processor API token from airflow-init (which mints it) to the processor. + airflow-api-token: diff --git a/airflow-core/newsfragments/67878.significant.rst b/airflow-core/newsfragments/67878.significant.rst new file mode 100644 index 0000000000000..7c221d643f334 --- /dev/null +++ b/airflow-core/newsfragments/67878.significant.rst @@ -0,0 +1,63 @@ +DAG processor now reads and writes metadata exclusively through the API server + +The standalone DAG processor (``airflow dag-processor``) no longer connects to the metadata +database directly. It persists parse results and reads all metadata through the API server, +mirroring how workers and triggerers already operate. + +**What changed:** + +- The DAG processor process opens no metadata-database connection. Persistence (serialized + DAGs, import errors, warnings), stale-DAG and orphaned-import-error reconciliation, bundle + synchronization and state, priority-parse-request and callback claiming, and the processor's + own ``Job`` liveness record all go through the API server. +- Parse-time ``Connection``/``Variable``/``XCom`` reads from DAG files resolve through the + Execution API rather than an in-process app backed by the local database. + +**Behaviour changes:** + +- An API server must be reachable for the DAG processor to run. A deployment that previously + ran ``airflow dag-processor`` with only a database connection (no API server) must now also + run the API server. +- The API server fronting the DAG processor must run the ``dag-processing`` app. Start it with + ``airflow api-server --apps all`` or include ``dag-processing`` explicitly; a subset such as + ``--apps core,execution`` omits it and the DAG processor cannot run. +- The DAG processor authenticates to the API server with a bearer token the deployment + provisions; it does not mint its own. Because it parses user code, it is not given the signing + key. A trusted component mints the token and writes it to the file referenced by + ``[dag_processor] api_token_path``; the processor only reads that file, re-reading it as the + token is rotated, so a refreshed token is picked up without a restart. +- ``airflow provision-dag-processor-token`` mints the token and writes it to that file. The + bundled Helm chart (an init container) and docker-compose example run it for you; in a custom + deployment, run it from a trusted context that holds the signing key before the processor starts, + and re-run it before ``[dag_processor] jwt_expiration_time`` elapses to rotate the token. + +**Configuration:** + +- ``[core] dag_processing_api_server_url`` selects the DAG Processing API endpoint. It defaults + to the ``/dag-processing`` mount on the configured API server (``{BASE_URL}/dag-processing``), + so deployments that already run the API server need no extra configuration. Set it to point + the DAG processor at a different host. +- ``[dag_processor] api_token_path`` is the path to a file holding the bearer token the DAG + processor presents to the API server (for both the DAG Processing and Execution APIs). +- ``[dag_processor] jwt_expiration_time`` is the lifetime of the minted token (default 24h); the + provisioning step should re-mint before it elapses. + +**Migration:** + +- Ensure the API server is running and reachable from the DAG processor. If the DAG processor + runs on a host without the default API server, set ``[core] dag_processing_api_server_url`` + (and ``[core] execution_api_server_url`` for parse-time metadata reads) to its address. +- Provision the DAG processor's token (``airflow provision-dag-processor-token``) from a trusted + context that holds the signing key, and point the processor at it with + ``[dag_processor] api_token_path``. The Helm chart and docker-compose example do this by default. + +* Types of change + + * [ ] Dag changes + * [x] Config changes + * [x] CLI changes + * [x] Behaviour changes + * [ ] Plugin changes + * [ ] Dependency changes + * [x] Code interface changes + * [ ] API changes diff --git a/airflow-core/src/airflow/api_fastapi/app.py b/airflow-core/src/airflow/api_fastapi/app.py index 8931840c8807c..a8425658407b8 100644 --- a/airflow-core/src/airflow/api_fastapi/app.py +++ b/airflow-core/src/airflow/api_fastapi/app.py @@ -34,6 +34,7 @@ init_middlewares, init_views, ) +from airflow.api_fastapi.dag_processing.app import create_dag_processing_api_app from airflow.api_fastapi.execution_api.app import create_task_execution_api_app from airflow.configuration import conf from airflow.exceptions import AirflowConfigException @@ -61,7 +62,7 @@ def get_cookie_path() -> str: # Fast API apps mounted under these prefixes are not allowed -RESERVED_URL_PREFIXES = ["/api/v2", "/ui", "/execution", "/auth", "/pluginsv2"] +RESERVED_URL_PREFIXES = ["/api/v2", "/ui", "/execution", "/auth", "/pluginsv2", "/dag-processing"] log = logging.getLogger(__name__) @@ -107,6 +108,9 @@ def create_app(apps: str = "all") -> FastAPI: init_error_handlers(task_exec_api_app) app.mount("/execution", task_exec_api_app) + if "all" in apps_list or "dag-processing" in apps_list: + app.mount("/dag-processing", create_dag_processing_api_app()) + if "all" in apps_list or "core" in apps_list: app.state.dag_bag = dag_bag init_plugins(app) diff --git a/airflow-core/src/airflow/api_fastapi/auth/dag_processor_token.py b/airflow-core/src/airflow/api_fastapi/auth/dag_processor_token.py new file mode 100644 index 0000000000000..f632e0b651562 --- /dev/null +++ b/airflow-core/src/airflow/api_fastapi/auth/dag_processor_token.py @@ -0,0 +1,103 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Mint and provision the bearer token the DAG processor presents to the API server (AIP-92). + +The DAG processor parses (and forks) user code, so it must never hold the deployment signing key +or mint its own token. A *trusted* component runs the helpers here -- the deployment's provisioning +step (a Helm init container, a docker-compose init service) or ``airflow standalone`` -- mints the +token and writes it to ``[dag_processor] api_token_path``. The processor only ever reads that file +(re-reading it as it is rotated), so it carries a token without being able to forge one. +""" + +from __future__ import annotations + +import logging +import os +from pathlib import Path + +from airflow.api_fastapi.auth.tokens import JWTGenerator, get_signing_args +from airflow.configuration import conf + +log = logging.getLogger(__name__) + +# The Execution API is task-instance scoped: its ``sub`` is validated as a UUID. The DAG processor +# is not a task instance, so its token carries an all-zero sentinel UUID rather than a real id. +DAG_PROCESSOR_TOKEN_SUBJECT = "00000000-0000-0000-0000-000000000000" + + +def mint_dag_processor_token( + *, + valid_for: int | None = None, + make_secret_key_if_needed: bool = False, +) -> str: + """ + Mint the bearer token the DAG processor presents to the API server. + + Trusted callers only. The token carries *both* the Execution and DAG Processing audiences (the + processor calls both APIs) under :data:`DAG_PROCESSOR_TOKEN_SUBJECT`. Signing uses whatever + ``get_signing_args`` resolves -- the symmetric ``[api_auth] jwt_secret`` or, where configured, an + asymmetric private key -- so the same minting works for a single deployment or a JWKS-based + control plane without change here. + + :param valid_for: token lifetime in seconds; defaults to ``[dag_processor] jwt_expiration_time``. + :param make_secret_key_if_needed: generate a signing key if none is configured. Leave ``False`` + in real deployments (the key must be the shared one the API server validates with); only + ``airflow standalone``, which materialises that shared key itself, passes ``True``. + """ + audiences = [ + conf.get_mandatory_list_value("execution_api", "jwt_audience")[0], + conf.get_mandatory_list_value("dag_processor", "jwt_audience")[0], + ] + generator = JWTGenerator( + valid_for=valid_for if valid_for is not None else conf.getint("dag_processor", "jwt_expiration_time"), + # A JWT ``aud`` may be a list; the generator's hint is single-audience, but the processor + # presents this one token to both the Execution and DAG Processing APIs. + audience=audiences, # type: ignore[arg-type] + issuer=conf.get("api_auth", "jwt_issuer", fallback=None), + **get_signing_args(make_secret_key_if_needed=make_secret_key_if_needed), + ) + return generator.generate({"sub": DAG_PROCESSOR_TOKEN_SUBJECT}) + + +def provision_dag_processor_token_file( + output: str | os.PathLike[str] | None = None, + *, + valid_for: int | None = None, + make_secret_key_if_needed: bool = False, +) -> str: + """ + Mint a DAG processor token and write it to ``output`` (or ``[dag_processor] api_token_path``). + + The token is written atomically (temp file then ``rename``) so the processor, which re-reads the + file, never observes a half-written token. Re-run before the token expires to rotate it in place. + + :return: the path written. + """ + target = output or conf.get("dag_processor", "api_token_path", fallback=None) + if not target: + raise ValueError( + "No output path for the DAG processor token: pass output or set [dag_processor] api_token_path." + ) + token = mint_dag_processor_token(valid_for=valid_for, make_secret_key_if_needed=make_secret_key_if_needed) + path = Path(target) + path.parent.mkdir(parents=True, exist_ok=True) + tmp_path = path.with_name(f"{path.name}.tmp") + tmp_path.write_text(token) + os.replace(tmp_path, path) + log.info("Wrote DAG processor API token to %s", path) + return str(path) diff --git a/airflow-core/src/airflow/api_fastapi/dag_processing/__init__.py b/airflow-core/src/airflow/api_fastapi/dag_processing/__init__.py new file mode 100644 index 0000000000000..fb62457db1aaf --- /dev/null +++ b/airflow-core/src/airflow/api_fastapi/dag_processing/__init__.py @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +DAG Processing API (AIP-92). + +A FastAPI sub-app that lets the DAG processor persist parse results without a +direct metadata-DB connection. The processor parses files locally and POSTs the +results here; this app owns the DB writes. +""" + +from __future__ import annotations diff --git a/airflow-core/src/airflow/api_fastapi/dag_processing/app.py b/airflow-core/src/airflow/api_fastapi/dag_processing/app.py new file mode 100644 index 0000000000000..689e374948a7f --- /dev/null +++ b/airflow-core/src/airflow/api_fastapi/dag_processing/app.py @@ -0,0 +1,322 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""DAG Processing API sub-app (AIP-92): persistence endpoints for the DAG processor.""" + +from __future__ import annotations + +import logging +from datetime import timedelta + +from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request +from fastapi.responses import JSONResponse +from sqlalchemy import select, update +from sqlalchemy.orm import load_only + +from airflow._shared.timezones import timezone +from airflow.api_fastapi.dag_processing.datamodels import ( + BundleStateResponse, + BundleStateUpdateBody, + CallbackClaimBody, + JobCompleteBody, + JobRegisterBody, + ParsingResultBody, + PriorityClaimBody, + ReconcileBody, + StaleDagsBody, +) +from airflow.api_fastapi.dag_processing.security import require_dag_processing_auth +from airflow.dag_processing.bundles.manager import DagBundlesManager +from airflow.dag_processing.collection import update_dag_parsing_results_in_db +from airflow.jobs.job import Job +from airflow.models.asset import remove_references_to_deleted_dags +from airflow.models.dag import DagModel +from airflow.models.dagbag import DagPriorityParsingRequest +from airflow.models.dagbundle import DagBundleModel +from airflow.models.dagwarning import DagWarning +from airflow.models.db_callback_request import DbCallbackRequest +from airflow.models.errors import ParseImportError +from airflow.serialization.serialized_objects import LazyDeserializedDAG +from airflow.utils.session import create_session +from airflow.utils.sqlalchemy import prohibit_commit, with_row_locks + +log = logging.getLogger(__name__) + +router = APIRouter() + + +def health() -> dict: + return {"status": "healthy"} + + +@router.post("/parsing-results", status_code=201) +def persist_parsing_results(body: ParsingResultBody) -> dict: + """Persist one file's parse results. Mirrors ``persist_parsing_result``.""" + try: + dags = [LazyDeserializedDAG.model_validate(d) for d in body.serialized_dags] + except Exception as e: + raise HTTPException(status_code=422, detail=f"Invalid serialized_dags payload: {e}") + + import_errors: dict[tuple[str, str], str] = {} + if body.import_errors: + import_errors = {(body.bundle_name, rel): err for rel, err in body.import_errors.items()} + + files_parsed: set[tuple[str, str]] | None = None + if body.relative_fileloc is not None: + files_parsed = {(body.bundle_name, body.relative_fileloc)} + files_parsed.update(import_errors.keys()) + + warnings = [DagWarning(**warn) for warn in (body.warnings or [])] + + with create_session() as session: + update_dag_parsing_results_in_db( + bundle_name=body.bundle_name, + bundle_version=body.bundle_version, + version_data=body.version_data, + dags=dags, + import_errors=import_errors, + parse_duration=body.run_duration, + warnings=set(warnings), + session=session, + files_parsed=files_parsed, + ) + + return {"bundle_name": body.bundle_name, "relative_fileloc": body.relative_fileloc} + + +@router.post("/bundles/{bundle_name}/reconcile") +def reconcile_bundle(bundle_name: str, body: ReconcileBody) -> dict: + """ + Deactivate DAGs/import-errors for files no longer present in the bundle. + + Mirrors ``deactivate_deleted_dags`` + ``clear_orphaned_import_errors``. These run in + separate transactions, and import-error cleanup swallows its own errors, so a cleanup + failure cannot roll back the deactivations (matching the original behaviour). + """ + observed = set(body.observed_filelocs) + + with create_session() as session: + deactivated = DagModel.deactivate_deleted_dags( + bundle_name=bundle_name, + rel_filelocs=observed, + session=session, + ) + if deactivated: + remove_references_to_deleted_dags(session=session) + + try: + with create_session() as session: + errors = session.scalars( + select(ParseImportError) + .where(ParseImportError.bundle_name == bundle_name) + .options(load_only(ParseImportError.filename)) + ) + for error in errors: + if error.filename not in observed: + session.delete(error) + except Exception: + log.exception("Error removing old import errors for bundle %s", bundle_name) + + return {"bundle_name": bundle_name, "deactivated": bool(deactivated)} + + +@router.get("/bundles/{bundle_name}/state", response_model=BundleStateResponse) +def get_bundle_state(bundle_name: str) -> BundleStateResponse: + """Return a bundle's persisted refresh state (last_refreshed + version).""" + with create_session() as session: + row = session.scalar( + select(DagBundleModel) + .where(DagBundleModel.name == bundle_name) + .options(load_only(DagBundleModel.last_refreshed, DagBundleModel.version)) + ) + if row is None: + return BundleStateResponse(found=False) + return BundleStateResponse(found=True, last_refreshed=row.last_refreshed, version=row.version) + + +@router.patch("/bundles/{bundle_name}/state") +def update_bundle_state(bundle_name: str, body: BundleStateUpdateBody) -> dict: + """Persist a bundle's post-refresh state. Updates ``version`` only when provided.""" + values: dict = {"last_refreshed": body.last_refreshed} + if body.version is not None: + values["version"] = body.version + with create_session() as session: + session.execute(update(DagBundleModel).where(DagBundleModel.name == bundle_name).values(**values)) + return {"bundle_name": bundle_name} + + +@router.post("/bundles/sync") +def sync_bundles() -> dict: + """Sync the configured DAG bundles to the metadata database (server-side).""" + DagBundlesManager().sync_bundles_to_db() + return {"synced": True} + + +@router.post("/stale-dags") +def deactivate_stale_dags(body: StaleDagsBody) -> dict: + """ + Deactivate DAGs whose files have not been re-parsed within the stale threshold. + + Mirrors ``DagFileProcessorManager.deactivate_stale_dags`` server-side. + """ + last_parsed = {(e.bundle_name, e.relative_fileloc): e.last_finish_time for e in body.last_parsed} + to_deactivate: set[str] = set() + deactivated = 0 + with create_session() as session: + inactive_bundles = set( + session.scalars(select(DagBundleModel.name).where(DagBundleModel.active.is_(False))).all() + ) + rows = session.execute( + select( + DagModel.dag_id, + DagModel.bundle_name, + DagModel.last_parsed_time, + DagModel.relative_fileloc, + ).where(~DagModel.is_stale) + ) + for row in rows: + if row.bundle_name in inactive_bundles: + to_deactivate.add(row.dag_id) + continue + last_finish_time = last_parsed.get((row.bundle_name, row.relative_fileloc)) + if last_finish_time and ( + row.last_parsed_time + timedelta(seconds=body.stale_dag_threshold) < last_finish_time + ): + to_deactivate.add(row.dag_id) + if to_deactivate: + result = session.execute( + update(DagModel) + .where(DagModel.dag_id.in_(to_deactivate)) + .values(is_stale=True) + .execution_options(synchronize_session="fetch") + ) + deactivated = getattr(result, "rowcount", 0) + return {"deactivated": deactivated} + + +@router.post("/purge-warnings") +def purge_inactive_dag_warnings() -> dict: + """Delete warnings for inactive/stale DAGs (server-side).""" + DagWarning.purge_inactive_dag_warnings() + return {"purged": True} + + +@router.post("/priority-parse-requests/claim") +def claim_priority_parse_requests(body: PriorityClaimBody) -> dict: + """Claim (select + delete, one transaction) priority parse requests for the given bundles.""" + claimed: list[dict] = [] + with create_session() as session: + requests = session.scalars( + select(DagPriorityParsingRequest).where( + DagPriorityParsingRequest.bundle_name.in_(body.bundle_names) + ) + ) + for request in requests: + claimed.append({"bundle_name": request.bundle_name, "relative_fileloc": request.relative_fileloc}) + session.delete(request) + return {"claimed": claimed} + + +@router.post("/callbacks/claim") +def claim_callbacks(body: CallbackClaimBody) -> dict: + """ + Claim callbacks for the given bundles using FOR UPDATE SKIP LOCKED, server-side. + + Mirrors ``DagFileProcessorManager._fetch_callbacks_from_db``. Returns each claimed + callback's raw ``{req_class, req_data}`` payload so the caller can rebuild the typed + ``CallbackRequest`` exactly as ``DbCallbackRequest.get_callback_request`` does. + """ + claimed: list[dict] = [] + with create_session() as session: + with prohibit_commit(session) as guard: + query = with_row_locks( + select(DbCallbackRequest) + .where(DbCallbackRequest.bundle_name.in_(body.bundle_names)) + .order_by(DbCallbackRequest.priority_weight.desc()) + .limit(body.limit), + of=DbCallbackRequest, + session=session, + skip_locked=True, + ) + callbacks = [cb[0] if isinstance(cb, tuple) else cb for cb in session.scalars(query)] + for callback in callbacks: + claimed.append(callback.data) + session.delete(callback) + guard.commit() + return {"callbacks": claimed} + + +@router.post("/jobs", status_code=201) +def register_job(body: JobRegisterBody) -> dict: + """Register the processor's liveness Job row (server-side) and return its id.""" + job = Job() + job.job_type = body.job_type + # ``Job()`` defaults hostname/unixname/pid to *this* API server's, but the processor runs in a + # different process (and usually host). Record the identity it reported so its health check + # (``airflow jobs check --job-type DagProcessorJob --hostname ``) matches this row. + job.hostname = body.hostname + if body.unixname: + job.unixname = body.unixname + if body.pid is not None: + job.pid = body.pid + with create_session() as session: + job.prepare_for_execution(session=session) + return {"job_id": job.id} + + +@router.post("/jobs/{job_id}/heartbeat") +def job_heartbeat(job_id: int) -> dict: + """Update the processor Job's latest_heartbeat so the health check sees it alive.""" + with create_session() as session: + job = session.get(Job, job_id) + if job is None: + raise HTTPException(status_code=404, detail="Job not found") + job.latest_heartbeat = timezone.utcnow() + session.merge(job) + return {"alive": True} + + +@router.post("/jobs/{job_id}/complete") +def complete_job(job_id: int, body: JobCompleteBody) -> dict: + """Record the processor Job's terminal state and end time.""" + with create_session() as session: + job = session.get(Job, job_id) + if job is not None: + job.end_date = timezone.utcnow() + job.state = body.state + session.merge(job) + return {"completed": True} + + +def create_dag_processing_api_app() -> FastAPI: + """Create the DAG Processing API sub-app (mounted at ``/dag-processing``).""" + app = FastAPI( + title="Airflow DAG Processing API", + description="Persistence endpoints for the DAG processor (AIP-92).", + ) + + @app.exception_handler(Exception) + async def _unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse: + # Mounted sub-apps build their own middleware stack, so without this the parent + # app never logs unhandled exceptions raised here. + log.exception("Unhandled exception in DAG Processing API: %s %s", request.method, request.url.path) + return JSONResponse(status_code=500, content={"detail": "Internal Server Error"}) + + # /health stays unauthenticated so external readiness probes work; every persistence + # endpoint requires a valid DAG-processor token. + app.add_api_route("/health", health, methods=["GET"]) + app.include_router(router, dependencies=[Depends(require_dag_processing_auth)]) + return app diff --git a/airflow-core/src/airflow/api_fastapi/dag_processing/datamodels.py b/airflow-core/src/airflow/api_fastapi/dag_processing/datamodels.py new file mode 100644 index 0000000000000..494d9e96d6ab7 --- /dev/null +++ b/airflow-core/src/airflow/api_fastapi/dag_processing/datamodels.py @@ -0,0 +1,105 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Request/response models for the DAG Processing API (AIP-92).""" + +from __future__ import annotations + +from datetime import datetime + +from pydantic import BaseModel + + +class ParsingResultBody(BaseModel): + """One file's parse output, mirroring ``DagFileProcessorManager.persist_parsing_result``.""" + + bundle_name: str + bundle_version: str | None = None + version_data: dict | None = None + relative_fileloc: str | None = None + run_duration: float | None = None + serialized_dags: list[dict] = [] + import_errors: dict[str, str] | None = None + warnings: list[dict] | None = None + + +class ReconcileBody(BaseModel): + """The full set of source paths currently observed in a bundle, for stale reconciliation.""" + + observed_filelocs: list[str] + + +class BundleStateResponse(BaseModel): + """A bundle's persisted refresh state (``found=False`` when it has no record).""" + + found: bool + last_refreshed: datetime | None = None + version: str | None = None + + +class BundleStateUpdateBody(BaseModel): + """Post-refresh state for a bundle. ``version=None`` leaves the stored version unchanged.""" + + last_refreshed: datetime + version: str | None = None + + +class StaleDagEntry(BaseModel): + """A parsed file's identity and last parse-finish time, for the stale-dag sweep.""" + + bundle_name: str + relative_fileloc: str + last_finish_time: datetime + + +class StaleDagsBody(BaseModel): + """Inputs for the time-based stale-dag sweep.""" + + stale_dag_threshold: int + last_parsed: list[StaleDagEntry] + + +class PriorityClaimBody(BaseModel): + """Bundle names to claim priority parse requests for.""" + + bundle_names: list[str] + + +class CallbackClaimBody(BaseModel): + """Bundle names and per-loop limit for claiming callbacks.""" + + bundle_names: list[str] + limit: int + + +class JobRegisterBody(BaseModel): + """ + Identity of the processor whose liveness Job is being registered. + + ``hostname``/``unixname``/``pid`` are the *processor's* (not this API server's), so the Job row + reflects where the processor runs and ``airflow jobs check --hostname`` matches it. + """ + + job_type: str + hostname: str + unixname: str | None = None + pid: int | None = None + + +class JobCompleteBody(BaseModel): + """Terminal state to record when the processor stops.""" + + state: str diff --git a/airflow-core/src/airflow/api_fastapi/dag_processing/security.py b/airflow-core/src/airflow/api_fastapi/dag_processing/security.py new file mode 100644 index 0000000000000..0449fe9b3cb98 --- /dev/null +++ b/airflow-core/src/airflow/api_fastapi/dag_processing/security.py @@ -0,0 +1,73 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Authentication for the DAG Processing API (AIP-92).""" + +from __future__ import annotations + +import functools +import logging + +from fastapi import Depends, HTTPException, status +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer + +from airflow.api_fastapi.auth.tokens import JWTValidator, get_sig_validation_args +from airflow.configuration import conf + +log = logging.getLogger(__name__) + + +@functools.lru_cache(maxsize=1) +def _token_validator() -> JWTValidator: + """ + Build the JWT validator for the DAG Processing API. + + ``get_sig_validation_args`` selects symmetric (deployment ``jwt_secret``) or asymmetric/JWKS + validation (``[api_auth] trusted_jwks_url``) the same way the Execution API does, so a + multi-tenant deployment validates externally-issued tokens without any change here. + """ + required_claims = frozenset({"aud", "exp", "iat"}) + issuer = conf.get("api_auth", "jwt_issuer", fallback=None) + if issuer: + required_claims = required_claims | {"iss"} + return JWTValidator( + required_claims=required_claims, + issuer=issuer, + audience=conf.get_mandatory_list_value("dag_processor", "jwt_audience"), + **get_sig_validation_args(make_secret_key_if_needed=False), + ) + + +_bearer = HTTPBearer(auto_error=False) + + +def require_dag_processing_auth( + creds: HTTPAuthorizationCredentials | None = Depends(_bearer), +) -> None: + """ + Reject DAG Processing API calls that lack a valid DAG-processor token. + + The DAG processor self-signs a token for the configured ``[dag_processor] jwt_audience`` with + the deployment signing key; this validates the signature, audience, and issuer so the + persistence endpoints are not callable unauthenticated. + """ + if creds is None: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing auth token") + try: + _token_validator().validated_claims(creds.credentials) + except Exception: + log.warning("Invalid DAG Processing API token", exc_info=True) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid auth token") diff --git a/airflow-core/src/airflow/cli/cli_config.py b/airflow-core/src/airflow/cli/cli_config.py index c43d79d5bda49..a09da021f459f 100644 --- a/airflow-core/src/airflow/cli/cli_config.py +++ b/airflow-core/src/airflow/cli/cli_config.py @@ -218,6 +218,12 @@ def string_lower_type(val): type=str, default="[CWD]" if BUILD_DOCS else os.getcwd(), ) +ARG_DAG_PROCESSOR_TOKEN_OUTPUT = Arg( + ("--output",), + help="File to write the token to. Defaults to the `[dag_processor] api_token_path` config.", + type=str, + default=None, +) ARG_PID = Arg(("--pid",), help="PID file location", nargs="?") ARG_DAEMON = Arg( ("-D", "--daemon"), help="Daemonize instead of running in the foreground", action="store_true" @@ -752,7 +758,7 @@ def string_lower_type(val): ) ARG_API_SERVER_APPS = Arg( ("--apps",), - help="Applications to run (comma-separated). Default is all. Options: core, execution, all", + help="Applications to run (comma-separated). Default is all. Options: core, execution, dag-processing, all", default="all", ) ARG_API_SERVER_ALLOW_PROXY_FORWARDING = Arg( @@ -2249,6 +2255,18 @@ class GroupCommand(NamedTuple): ARG_DEV, ), ), + ActionCommand( + name="provision-dag-processor-token", + help=( + "Mint the DAG processor's API token and write it to a file. Run by a trusted " + "deployment step (which holds the signing key), not by the processor itself." + ), + func=lazy_load_command("airflow.cli.commands.dag_processor_command.provision_dag_processor_token"), + args=( + ARG_DAG_PROCESSOR_TOKEN_OUTPUT, + ARG_VERBOSE, + ), + ), ActionCommand( name="version", help="Show the version", diff --git a/airflow-core/src/airflow/cli/commands/dag_processor_command.py b/airflow-core/src/airflow/cli/commands/dag_processor_command.py index f4c303c278dbb..c8c0a67005027 100644 --- a/airflow-core/src/airflow/cli/commands/dag_processor_command.py +++ b/airflow-core/src/airflow/cli/commands/dag_processor_command.py @@ -19,12 +19,14 @@ from __future__ import annotations import logging +import time from typing import Any +from airflow.api_fastapi.auth.dag_processor_token import provision_dag_processor_token_file from airflow.cli.commands.daemon_utils import run_command_with_daemon_option from airflow.dag_processing.manager import DagFileProcessorManager from airflow.jobs.dag_processor_job_runner import DagProcessorJobRunner -from airflow.jobs.job import Job, run_job +from airflow.jobs.job import Job from airflow.utils import cli as cli_utils from airflow.utils.memray_utils import MemrayTraceComponents, enable_memray_trace from airflow.utils.providers_configuration_loader import providers_configuration_loaded @@ -45,6 +47,67 @@ def _create_dag_processor_job_runner(args: Any) -> DagProcessorJobRunner: ) +def _run_dag_processor_job(job_runner: DagProcessorJobRunner) -> None: + """ + Run the job, registering the liveness Job through the DAG Processing API. + + The DAG processor holds no metadata-DB connection, so its Job row is registered, + heartbeated, and completed through the DAG Processing API instead of writing the + database directly. The first call waits for the API server to be ready (the processor and + server may start concurrently), the heartbeat is throttled to the job heartrate and never + crashes the loop, and a failure to complete the Job does not mask a parsing-loop error. + """ + client = job_runner.processor._dag_processing_client + client.wait_until_ready() + job_id = client.register_job(job_runner.job_type) + + heartrate = getattr(job_runner.job, "heartrate", 5.0) or 5.0 + last_heartbeat = 0.0 + + def _heartbeat() -> None: + nonlocal last_heartbeat + now = time.monotonic() + if now - last_heartbeat < heartrate: + return + try: + client.job_heartbeat(job_id) + last_heartbeat = now + except Exception: + # Liveness update failure must not crash the processor; the next loop retries. + log.warning("DAG Processing API heartbeat failed; retrying next loop", exc_info=True) + + job_runner.processor.heartbeat = _heartbeat + state = "success" + try: + job_runner._execute() + except SystemExit: + pass + except Exception: + state = "failed" + raise + finally: + try: + client.complete_job(job_id, state=state) + except Exception: + # Don't let a completion failure mask the original parsing-loop exception. + log.warning("Failed to mark DAG processor Job %s as %s", job_id, state, exc_info=True) + + +@cli_utils.action_cli +@providers_configuration_loaded +def provision_dag_processor_token(args): + """ + Mint the DAG processor's API token and write it to a file (trusted bootstrap step). + + Run by a trusted component (a deployment init step, not the processor itself), which holds the + signing key. Writes to ``--output`` or, by default, ``[dag_processor] api_token_path``. The DAG + processor then only reads that file. Re-run before ``[dag_processor] jwt_expiration_time`` + elapses to rotate the token in place. + """ + path = provision_dag_processor_token_file(getattr(args, "output", None)) + print(f"DAG processor API token written to {path}") + + @enable_memray_trace(component=MemrayTraceComponents.dag_processor) @cli_utils.action_cli @providers_configuration_loaded @@ -56,7 +119,7 @@ def dag_processor(args): from airflow.cli.hot_reload import run_with_reloader run_with_reloader( - lambda: run_job(job=job_runner.job, execute_callable=job_runner._execute), + lambda: _run_dag_processor_job(job_runner), process_name="dag-processor", ) return @@ -64,6 +127,6 @@ def dag_processor(args): run_command_with_daemon_option( args=args, process_name="dag-processor", - callback=lambda: run_job(job=job_runner.job, execute_callable=job_runner._execute), + callback=lambda: _run_dag_processor_job(job_runner), should_setup_logging=True, ) diff --git a/airflow-core/src/airflow/cli/commands/standalone_command.py b/airflow-core/src/airflow/cli/commands/standalone_command.py index 2e94637e763c4..b6403c84b85d4 100644 --- a/airflow-core/src/airflow/cli/commands/standalone_command.py +++ b/airflow-core/src/airflow/cli/commands/standalone_command.py @@ -20,6 +20,7 @@ import os import socket import subprocess +import tempfile import threading import time from collections import deque @@ -28,6 +29,8 @@ from termcolor import colored from airflow.api_fastapi.app import create_auth_manager +from airflow.api_fastapi.auth.dag_processor_token import provision_dag_processor_token_file +from airflow.api_fastapi.auth.tokens import get_signing_key from airflow.configuration import conf from airflow.executors import executor_constants from airflow.executors.executor_loader import ExecutorLoader @@ -189,8 +192,37 @@ def calculate_env(self): env["AIRFLOW__CORE__AUTH_MANAGER"] = simple_auth_manager_classpath os.environ["AIRFLOW__CORE__AUTH_MANAGER"] = simple_auth_manager_classpath # also in this process! + self._provision_dag_processor_token(env) + return env + def _provision_dag_processor_token(self, env: dict[str, str]) -> None: + """ + Mint the API token the DAG processor presents and point it at a file via env. + + The DAG processor parses user code, so it must not hold the signing key or mint its own + token. Standalone runs in a trusted context (it already holds the signing key the + scheduler uses to mint task tokens), so it mints the processor's token here and provisions + it through ``[dag_processor] api_token_path``. The token carries both the Execution and DAG + Processing audiences (the processor calls both APIs) and a sentinel subject, since the + Execution API is task-instance scoped. + """ + try: + # Materialise a signing key shared by every standalone subprocess, so the api-server + # validates the token with the same key it is minted with here. + secret = get_signing_key("api_auth", "jwt_secret", make_secret_key_if_needed=True) + env["AIRFLOW__API_AUTH__JWT_SECRET"] = secret + os.environ["AIRFLOW__API_AUTH__JWT_SECRET"] = secret + + # Standalone is the trusted launcher, so it mints the processor's token here (the + # processor never mints its own); the same minting is used by real deployments. + fd, path = tempfile.mkstemp(prefix="airflow-standalone-", suffix=".token") + os.close(fd) + provision_dag_processor_token_file(path, make_secret_key_if_needed=True) + env["AIRFLOW__DAG_PROCESSOR__API_TOKEN_PATH"] = path + except Exception as e: + self.print_output("standalone", f"Could not provision the DAG processor API token: {e}") + def find_user_info(self): if conf.get("core", "simple_auth_manager_all_admins").lower() == "true": # If we have no auth anyways, no need to print or do anything diff --git a/airflow-core/src/airflow/config_templates/config.yml b/airflow-core/src/airflow/config_templates/config.yml index 74332719d0c55..26a676a1d1262 100644 --- a/airflow-core/src/airflow/config_templates/config.yml +++ b/airflow-core/src/airflow/config_templates/config.yml @@ -511,6 +511,17 @@ core: type: string example: ~ default: ~ + dag_processing_api_server_url: + description: | + The url of the DAG processing api server. The DAG processor never accesses the metadata + database directly; it persists parse results and reads metadata through this api server. + Defaults to ``{BASE_URL}/dag-processing`` (a sibling of ``execution_api_server_url``), + where ``{BASE_URL}`` is the base url of the API server. Set this to point the DAG + processor at a different host. + version_added: 3.3.0 + type: string + example: "http://localhost:8080/dag-processing" + default: ~ multi_team: description: | Whether to run the Airflow environment in multi-team mode. @@ -2974,6 +2985,40 @@ dag_processor: example: "/tmp/some-place" default: ~ + jwt_audience: + version_added: 3.3.0 + description: | + The audience claim the DAG Processing API validates on tokens the DAG processor presents. + Can be a single value or a comma-separated string (all are accepted at validation time). + The token itself is minted by the deployment, not the processor; see ``api_token_path``. + example: "my-unique-airflow-id" + default: "urn:airflow.apache.org:dag-processing" + type: string + + api_token_path: + version_added: 3.3.0 + description: | + Path to a file containing the bearer token the DAG processor presents to the API server + (for both the DAG Processing and Execution APIs). The DAG processor parses user code, so + it must not hold the signing key or mint its own token; the deployment (or control plane) + provisions this token and writes it to the file. The processor re-reads this file as it is + rotated, so a refreshed token is picked up without a restart. If unset, the processor sends + no token. + example: "/var/run/airflow/dag-processor-token" + default: ~ + type: string + + jwt_expiration_time: + version_added: 3.3.0 + description: | + Lifetime, in seconds, of the token a trusted component (the deployment's token-provisioning + step, or ``airflow standalone``) mints for the DAG processor and writes to ``api_token_path``. + The provisioning step should re-mint before this elapses so the long-running processor keeps + a valid token. The processor itself never mints a token; this only configures the minter. + default: "86400" + type: integer + example: ~ + dag_bundle_config_list: description: | List of backend configs. Must supply name, classpath, and kwargs for each backend. diff --git a/airflow-core/src/airflow/dag_processing/api_client.py b/airflow-core/src/airflow/dag_processing/api_client.py new file mode 100644 index 0000000000000..295b035409b4f --- /dev/null +++ b/airflow-core/src/airflow/dag_processing/api_client.py @@ -0,0 +1,293 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""HTTP client used by the DAG processor to persist parse results via the API (AIP-92).""" + +from __future__ import annotations + +import logging +import os +import time +from datetime import datetime +from getpass import getuser +from importlib import import_module +from typing import TYPE_CHECKING, Any +from urllib.parse import quote + +import httpx + +from airflow.utils.net import get_hostname + +if TYPE_CHECKING: + from collections.abc import Callable, Iterable + + from airflow.dag_processing.processor import DagFileParsingResult + +log = logging.getLogger(__name__) + +# Connection-level retries are safe for every request (the request never reached the server), +# so they are applied uniformly by the transport. Application-level retries (5xx / read timeout) +# are only added for idempotent calls; claim/delete calls must not be retried after the request +# may have been processed server-side, or callbacks/priority-requests could be silently lost. +_CONNECT_RETRIES = 3 +_TRANSIENT_RETRIES = 3 +_TRANSIENT_BACKOFF = 0.5 + + +def _parse_iso_datetime(value: str) -> datetime: + """ + Parse an ISO-8601 timestamp from the API server, tolerating a trailing ``Z``. + + The server emits UTC timestamps with a ``Z`` designator (e.g. ``2026-06-03T03:52:13.938753Z``), + but ``datetime.fromisoformat`` only accepts ``Z`` on Python 3.11+. Normalise it to an explicit + offset so parsing works on the oldest supported runtime (3.10). + """ + if value.endswith("Z"): + value = f"{value[:-1]}+00:00" + return datetime.fromisoformat(value) + + +# How long a token read from the provisioned file is reused before re-reading. Keeps the tight +# parse loop from stat-ing the file on every request while still picking up a rotated token quickly. +_TOKEN_CACHE_TTL = 30.0 + + +class _CallableBearerAuth(httpx.Auth): + """ + Attach a bearer token read from a callable, re-reading it as the deployment rotates it. + + The DAG processor never holds the signing key; a trusted component provisions its token to a + file (``[dag_processor] api_token_path``) and rotates it there before expiry. Reading the token + per-request (rather than baking it into the client once) lets a long-running processor pick up a + rotated token without restarting. The value is cached for ``_TOKEN_CACHE_TTL`` to avoid a file + read on every call; a ``401`` forces an immediate re-read and one retry, so a rotation that lands + mid-window is honoured without waiting out the cache. + """ + + def __init__(self, token_getter: Callable[[], str | None], *, cache_ttl: float = _TOKEN_CACHE_TTL): + self._token_getter = token_getter + self._cache_ttl = cache_ttl + self._cached: str | None = None + self._cached_at = 0.0 + + def _token(self, *, refresh: bool = False) -> str | None: + now = time.monotonic() + if refresh or self._cached is None or now - self._cached_at >= self._cache_ttl: + self._cached = self._token_getter() + self._cached_at = now + return self._cached + + def auth_flow(self, request: httpx.Request): + token = self._token() + if token: + request.headers["Authorization"] = f"Bearer {token}" + response = yield request + if response.status_code == 401: + # The token may have rotated on disk since the cached read; re-read and retry once. + fresh = self._token(refresh=True) + if fresh and fresh != token: + request.headers["Authorization"] = f"Bearer {fresh}" + yield request + + +class DagProcessingApiClient: + """ + Forward DAG-processor persistence to the ``/dag-processing`` API sub-app. + + Replaces the DAG processor manager's direct metadata-DB writes: parse results + and stale reconciliation are sent over HTTP, so the processor process needs no + DB credentials for persistence. A single :class:`httpx.Client` is reused so + connections are pooled across the manager's parse loop. + """ + + def __init__( + self, + base_url: str, + *, + token_getter: Callable[[], str | None] | None = None, + timeout: float = 30.0, + ) -> None: + self._base_url = base_url.rstrip("/") + # The token is read from ``token_getter`` per request (see _CallableBearerAuth), not baked + # in, so a token rotated on disk by the deployment is picked up without a restart. + auth = _CallableBearerAuth(token_getter) if token_getter is not None else None + # ``retries`` retries connection failures (request never sent) at the transport level. + self._client = httpx.Client( + auth=auth, + timeout=timeout, + transport=httpx.HTTPTransport(retries=_CONNECT_RETRIES), + ) + + def close(self) -> None: + self._client.close() + + def wait_until_ready(self, *, timeout: float = 60.0) -> None: + """ + Block until the DAG Processing API answers ``/health`` or ``timeout`` elapses. + + The DAG processor and API server may start concurrently (e.g. ``airflow standalone``); + this avoids crashing on a cold-start race before the server has bound its socket. + """ + deadline = time.monotonic() + timeout + delay = 0.5 + last_exc: Exception | None = None + while time.monotonic() < deadline: + try: + resp = self._client.get(f"{self._base_url}/health") + if resp.status_code < 500: + return + except httpx.HTTPError as exc: + last_exc = exc + time.sleep(delay) + delay = min(delay * 2, 5.0) + log.warning( + "DAG Processing API at %s not ready after %ss", self._base_url, timeout, exc_info=last_exc + ) + + def _send( + self, + method: str, + path: str, + *, + retry_transient: bool = True, + **kwargs: Any, + ) -> httpx.Response: + """ + Send a request, retrying transient failures. + + Connection failures are retried by the transport. When ``retry_transient`` is set + (idempotent calls only), read timeouts and 5xx responses are additionally retried with + backoff. Non-idempotent claim/delete calls pass ``retry_transient=False`` so a response + lost after the server processed the request does not cause a silent retry that drops rows. + """ + url = f"{self._base_url}{path}" + attempts = _TRANSIENT_RETRIES if retry_transient else 1 + for attempt in range(attempts): + try: + resp = self._client.request(method, url, **kwargs) + except httpx.TransportError: + if not retry_transient or attempt == attempts - 1: + raise + else: + if not (retry_transient and resp.status_code >= 500 and attempt < attempts - 1): + resp.raise_for_status() + return resp + time.sleep(_TRANSIENT_BACKOFF * (2**attempt)) + # Unreachable: the loop either returns or raises on the final attempt. + raise RuntimeError("unreachable") + + def persist_parsing_result( + self, + *, + bundle_name: str, + bundle_version: str | None, + version_data: dict | None, + parsing_result: DagFileParsingResult, + run_duration: float, + relative_fileloc: str | None, + ) -> None: + payload = { + "bundle_name": bundle_name, + "bundle_version": bundle_version, + "version_data": version_data, + "relative_fileloc": relative_fileloc, + "run_duration": run_duration, + "serialized_dags": [d.model_dump(mode="json") for d in parsing_result.serialized_dags], + "import_errors": parsing_result.import_errors, + "warnings": parsing_result.warnings or [], + } + self._send("POST", "/parsing-results", json=payload) + log.debug("Persisted parse result via API: bundle=%s file=%s", bundle_name, relative_fileloc) + + def reconcile(self, *, bundle_name: str, observed_filelocs: Iterable[str]) -> None: + payload = {"observed_filelocs": sorted(observed_filelocs)} + self._send("POST", f"/bundles/{quote(bundle_name, safe='')}/reconcile", json=payload) + log.debug("Reconciled bundle via API: bundle=%s", bundle_name) + + def get_bundle_state(self, bundle_name: str) -> dict | None: + """Return ``{last_refreshed, version}`` for a bundle, or ``None`` if it has no record.""" + resp = self._send("GET", f"/bundles/{quote(bundle_name, safe='')}/state") + data = resp.json() + if not data.get("found"): + return None + last_refreshed = data.get("last_refreshed") + return { + "last_refreshed": _parse_iso_datetime(last_refreshed) if last_refreshed else None, + "version": data.get("version"), + } + + def update_bundle_state(self, bundle_name: str, *, last_refreshed: datetime, version: str | None) -> None: + payload = {"last_refreshed": last_refreshed.isoformat(), "version": version} + self._send("PATCH", f"/bundles/{quote(bundle_name, safe='')}/state", json=payload) + + def sync_bundles(self) -> None: + self._send("POST", "/bundles/sync") + + def deactivate_stale_dags(self, *, stale_dag_threshold: int, last_parsed: list[dict]) -> None: + payload = {"stale_dag_threshold": stale_dag_threshold, "last_parsed": last_parsed} + self._send("POST", "/stale-dags", json=payload) + + def purge_inactive_dag_warnings(self) -> None: + self._send("POST", "/purge-warnings") + + def claim_priority_files(self, bundle_names: list[str]) -> list[dict]: + """Claim (select + delete) priority parse requests for the given bundles.""" + # Not idempotent (claim + delete): only the transport's connection-failure retry applies. + resp = self._send( + "POST", + "/priority-parse-requests/claim", + retry_transient=False, + json={"bundle_names": bundle_names}, + ) + return resp.json().get("claimed", []) + + def fetch_callbacks(self, *, bundle_names: list[str], limit: int) -> list: + """Claim callbacks via the API and rebuild typed ``CallbackRequest`` objects.""" + # Not idempotent (claim + delete): only the transport's connection-failure retry applies. + resp = self._send( + "POST", + "/callbacks/claim", + retry_transient=False, + json={"bundle_names": bundle_names, "limit": limit}, + ) + module = import_module("airflow.callbacks.callback_requests") + callbacks = [] + for data in resp.json().get("callbacks", []): + callback_request_class = getattr(module, data["req_class"]) + callbacks.append(callback_request_class.from_json(data["req_data"])) + return callbacks + + def register_job(self, job_type: str) -> int: + """Register the processor's liveness Job row server-side and return its id.""" + # Report this processor's own identity so the Job row reflects where the processor runs, + # not the API server; the health check matches on hostname (see the /jobs endpoint). + resp = self._send( + "POST", + "/jobs", + json={ + "job_type": job_type, + "hostname": get_hostname(), + "unixname": getuser(), + "pid": os.getpid(), + }, + ) + return resp.json()["job_id"] + + def job_heartbeat(self, job_id: int) -> None: + self._send("POST", f"/jobs/{job_id}/heartbeat") + + def complete_job(self, job_id: int, *, state: str) -> None: + self._send("POST", f"/jobs/{job_id}/complete", json={"state": state}) diff --git a/airflow-core/src/airflow/dag_processing/manager.py b/airflow-core/src/airflow/dag_processing/manager.py index fc00b6730c8ea..31ad8bc1a60af 100644 --- a/airflow-core/src/airflow/dag_processing/manager.py +++ b/airflow-core/src/airflow/dag_processing/manager.py @@ -32,40 +32,35 @@ import zipfile from collections import OrderedDict, defaultdict from dataclasses import dataclass, field -from datetime import datetime, timedelta +from datetime import datetime from operator import attrgetter, itemgetter from pathlib import Path from typing import TYPE_CHECKING, Any, Literal, NamedTuple, cast import attrs import structlog -from sqlalchemy import select, update -from sqlalchemy.orm import load_only from tabulate import tabulate from uuid6 import uuid7 from airflow._shared.observability.metrics import stats from airflow._shared.observability.metrics.stats import normalize_name_for_stats from airflow._shared.timezones import timezone -from airflow.api_fastapi.execution_api.app import InProcessExecutionAPI from airflow.configuration import conf +from airflow.dag_processing.api_client import DagProcessingApiClient from airflow.dag_processing.bundles.base import ( BundleUsageTrackingManager, unpack_bundle_version, ) from airflow.dag_processing.bundles.manager import DagBundlesManager -from airflow.dag_processing.collection import update_dag_parsing_results_in_db from airflow.dag_processing.processor import DagFileParsingResult, DagFileProcessorProcess from airflow.exceptions import AirflowException -from airflow.models.asset import remove_references_to_deleted_dags -from airflow.models.dag import DagModel -from airflow.models.dagbag import DagPriorityParsingRequest -from airflow.models.dagbundle import DagBundleModel -from airflow.models.dagwarning import DagWarning -from airflow.models.db_callback_request import DbCallbackRequest -from airflow.models.errors import ParseImportError +from airflow.executors.base_executor import get_execution_api_server_url from airflow.observability.metrics import stats_utils from airflow.sdk import SecretCache +from airflow.sdk.api.client import Client +from airflow.sdk.execution_time import task_runner +from airflow.sdk.execution_time.comms import GetConnection, GetVariable +from airflow.sdk.execution_time.request_handlers import handle_get_connection, handle_get_variable from airflow.sdk.log import init_log_file, logging_processors from airflow.typing_compat import assert_never from airflow.utils.file import list_py_file_paths, might_contain_dag @@ -74,20 +69,16 @@ from airflow.utils.process_utils import ( kill_child_processes_by_pids, ) -from airflow.utils.retries import retry_db_transaction -from airflow.utils.session import NEW_SESSION, create_session, provide_session -from airflow.utils.sqlalchemy import prohibit_commit, with_row_locks if TYPE_CHECKING: from collections.abc import Callable, Iterable, Iterator, Sequence from socket import socket - from sqlalchemy.orm import Session - from sqlalchemy.sql import Select - from airflow.callbacks.callback_requests import CallbackRequest from airflow.dag_processing.bundles.base import BaseDagBundle - from airflow.sdk.api.client import Client + + +log = logging.getLogger(__name__) class DagParsingStat(NamedTuple): @@ -149,6 +140,74 @@ def _config_get_factory(section: str, key: str): return functools.partial(conf.get, section, key) +def _dag_processing_api_server_url() -> str: + """ + Resolve the DAG Processing API URL the processor persists through. + + Defaults to the ``/dag-processing`` mount on the configured API server (a sibling of the + ``/execution`` mount), so a standard deployment that already runs the API server needs no + extra configuration. Set ``[core] dag_processing_api_server_url`` to point at a different + host. + + The default is derived from the resolved Execution API URL (``execution_api_server_url``), + not just ``[api] base_url``, so a deployment that points the processor's Execution API at an + internal service URL keeps both sibling clients on the same host. + """ + explicit = conf.get("core", "dag_processing_api_server_url", fallback=None) + if explicit: + return explicit + execution_url = get_execution_api_server_url().rstrip("/") + if execution_url.endswith("/execution"): + return f"{execution_url[: -len('/execution')]}/dag-processing" + return f"{execution_url}/dag-processing" + + +def _api_token() -> str | None: + """ + Return the token the DAG processor presents to the API server, or ``None``. + + The DAG processor parses (and forks) user code, so it must not hold the deployment signing + key or be able to mint tokens. It only *carries* a token provisioned by a trusted component: + read from the file at ``[dag_processor] api_token_path`` (written by the deployment / control + plane). The same token authenticates both the DAG Processing and Execution API clients. The + issuance of that token is intentionally left to the deployment (the AIP-92 non-task principal + question); the processor never signs one itself. + """ + token_path = conf.get("dag_processor", "api_token_path", fallback=None) + if not token_path: + return None + try: + return Path(token_path).read_text().strip() or None + except OSError: + log.warning("Could not read the DAG processor API token from %s", token_path) + return None + + +class _DagProcessorSecretsComms: + """ + Minimal ``SUPERVISOR_COMMS`` for the DAG processor process. + + Bundle initialization (e.g. ``GitDagBundle`` resolving its git connection) runs in the + manager process, which holds no metadata-DB connection. Installing this as + ``task_runner.SUPERVISOR_COMMS`` makes ``ensure_secrets_backend_loaded()`` select + ``ExecutionAPISecretsBackend``; this shim then resolves the connection/variable lookups it + sends through the manager's remote Execution API client -- the same path the worker and + triggerer use -- so DB-stored bundle credentials resolve without direct DB access. + """ + + def __init__(self, client: Client) -> None: + self._client = client + + def send(self, msg: Any, **kwargs: Any) -> Any: + if isinstance(msg, GetConnection): + return handle_get_connection(self._client, msg)[0] + if isinstance(msg, GetVariable): + return handle_get_variable(self._client, msg)[0] + # Other messages (e.g. MaskSecret emitted while masking a fetched secret) do not apply to + # the manager process; ignore them. + return None + + def _resolve_path(instance: Any, attribute: attrs.Attribute, val: str | os.PathLike[str] | None): if val is not None: val = Path(val).resolve() @@ -275,8 +334,12 @@ class DagFileProcessorManager(LoggingMixin): factory=_config_get_factory("dag_processor", "file_parsing_sort_mode") ) - _api_server: InProcessExecutionAPI = attrs.field(init=False, factory=InProcessExecutionAPI) - """API server to interact with Metadata DB""" + _dag_processing_client: DagProcessingApiClient = attrs.field( + init=False, + factory=lambda: DagProcessingApiClient(_dag_processing_api_server_url(), token_getter=_api_token), + ) + """Client for the DAG Processing API. The DAG processor never reads or writes the metadata + database directly; all persistence and metadata reads are routed through the API server.""" def register_exit_signals(self): """Register signals that stop child processes.""" @@ -295,8 +358,8 @@ def _exit_gracefully(self, signum, frame): sys.exit(os.EX_OK) def sync_bundles(self) -> None: - """Sync configured DAG bundles to the metadata database.""" - DagBundlesManager().sync_bundles_to_db() + """Sync configured DAG bundles to the metadata database via the DAG Processing API.""" + self._dag_processing_client.sync_bundles() def get_all_bundles(self) -> list[BaseDagBundle]: """Return configured DAG bundles filtered by ``bundle_names_to_parse`` if provided.""" @@ -317,35 +380,44 @@ def run(self): def before_run(self) -> None: """Set up state required before the parsing loop starts. Default implementation; override to customize.""" - self.prepare_server_process_context() self.prepare_process_context() self.register_exit_signals() self.log.info("Processing files using up to %s processes at a time ", self._parallelism) self.log.info("Process each file at most once every %s seconds", self._file_process_interval) + self._setup_secrets_comms() self.prepare_bundles() self._symlink_latest_log_directory() # To prevent COW in forked process parsing dag file gc.freeze() def after_run(self) -> None: - """Tear down state after the parsing loop exits. Default no-op; override to customize.""" + """Tear down state after the parsing loop exits.""" + # Drop the secrets shim installed in before_run so it does not outlive the loop. + if isinstance(getattr(task_runner, "SUPERVISOR_COMMS", None), _DagProcessorSecretsComms): + del task_runner.SUPERVISOR_COMMS - def prepare_server_process_context(self) -> None: + def _setup_secrets_comms(self) -> None: """ - Mark this process as running in "server" context so MetastoreBackend is available. + Route the manager's own connection/variable lookups through the remote Execution API. - Override to a no-op in subclasses that do not require direct DB access (e.g. API-backed - deployments under AIP-92). + Bundle initialization resolves credentials here (see :class:`_DagProcessorSecretsComms`). + Child parser processes reset ``SUPERVISOR_COMMS`` to their own comms in + ``_parse_file_entrypoint()``, so this only affects the manager process. """ - # TODO: Temporary until AIP-92 removes DB access from DagProcessorManager. - # The manager needs MetastoreBackend to retrieve connections from the database - # during bundle initialization (e.g., GitDagBundle.__init__ → GitHook needs git credentials). - # This marks the manager as "server" context so ensure_secrets_backend_loaded() provides - # MetastoreBackend instead of falling back to EnvironmentVariablesBackend only. - # Child parser processes explicitly override this by setting _AIRFLOW_PROCESS_CONTEXT=client - # in _parse_file_entrypoint() to prevent inheriting server privileges. - # Related: https://github.com/apache/airflow/pull/57459 - os.environ["_AIRFLOW_PROCESS_CONTEXT"] = "server" + try: + comms = _DagProcessorSecretsComms(self.client) + except Exception: + # Without a signing key the processor cannot authenticate to the Execution API; + # skip the shim so bundle init still resolves env/external-backend credentials + # instead of crashing. (DB-stored bundle credentials will not resolve in this case.) + self.log.warning( + "Could not initialize the Execution API client for bundle credentials; " + "bundle-init connections will resolve only from non-database secrets backends", + exc_info=True, + ) + return + # Duck-typed comms: the shim implements the .send() interface the secrets backend uses. + task_runner.SUPERVISOR_COMMS = comms # type: ignore[assignment] def prepare_process_context(self) -> None: """Initialize transport-neutral process state (selector, stats) before the parsing loop starts.""" @@ -406,62 +478,24 @@ def cleanup_stale_bundle_versions(self) -> None: """Clean up stale DAG bundle version usage records.""" BundleUsageTrackingManager().remove_stale_bundle_versions() - @provide_session - def deactivate_stale_dags( - self, - last_parsed: dict[DagFileInfo, datetime | None], - *, - session: Session = NEW_SESSION, - ): - """Detect and deactivate DAGs which are no longer present in files.""" - to_deactivate = set() - inactive_bundles = set( - session.scalars(select(DagBundleModel.name).where(DagBundleModel.active.is_(False))).all() - ) - query = select( - DagModel.dag_id, - DagModel.bundle_name, - DagModel.fileloc, - DagModel.last_parsed_time, - DagModel.relative_fileloc, - ).where(~DagModel.is_stale) - dags_parsed = session.execute(query) - - for dag in dags_parsed: - # Dags whose bundle has been removed from config (bundle no longer active) are stale — - # the processor has stopped parsing their files, so the time-based check below would never fire. - if dag.bundle_name in inactive_bundles: - self.log.info( - "Deactivating Dag %s. Its bundle %s is no longer active.", - dag.dag_id, - dag.bundle_name, - ) - to_deactivate.add(dag.dag_id) - continue - # When the Dag's last_parsed_time is more than the stale_dag_threshold older than the - # Dag file's last_finish_time, the Dag is considered stale as has apparently been removed from the file, - # This is especially relevant for Dag files that generate Dags in a dynamic manner. - file_info = DagFileInfo(rel_path=Path(dag.relative_fileloc), bundle_name=dag.bundle_name) - if last_finish_time := last_parsed.get(file_info, None): - if dag.last_parsed_time + timedelta(seconds=self.stale_dag_threshold) < last_finish_time: - self.log.info( - "Deactivating stale DAG %s. Not parsed for %s seconds (last parsed: %s).", - dag.dag_id, - int((last_finish_time - dag.last_parsed_time).total_seconds()), - dag.last_parsed_time, - ) - to_deactivate.add(dag.dag_id) - - if to_deactivate: - deactivated_dagmodel = session.execute( - update(DagModel) - .where(DagModel.dag_id.in_(to_deactivate)) - .values(is_stale=True) - .execution_options(synchronize_session="fetch") + def deactivate_stale_dags(self, last_parsed: dict[DagFileInfo, datetime | None]) -> None: + """Detect and deactivate DAGs which are no longer present in files, via the DAG Processing API.""" + entries = [ + { + "bundle_name": file_info.bundle_name, + "relative_fileloc": str(file_info.rel_path), + "last_finish_time": last_finish_time.isoformat(), + } + for file_info, last_finish_time in last_parsed.items() + if last_finish_time is not None + ] + try: + self._dag_processing_client.deactivate_stale_dags( + stale_dag_threshold=int(self.stale_dag_threshold), last_parsed=entries ) - deactivated = getattr(deactivated_dagmodel, "rowcount", 0) - if deactivated: - self.log.info("Deactivated %i DAGs which are no longer present in file.", deactivated) + except Exception: + # A transient API outage must not crash the parse loop; the next cycle retries. + self.log.exception("Error deactivating stale DAGs via the DAG Processing API") def _run_parsing_loop(self): # initialize cache to mutualize calls to Variable.get in DAGs @@ -551,12 +585,23 @@ def _queue_requested_files_for_parsing(self) -> None: self.log.info("Bundles being force refreshed: %s", ", ".join(self._force_refresh_bundles)) def claim_priority_files(self) -> list[DagFileInfo]: - """ - Fetch and claim files requested for priority parsing. - - Default implementation reads from the metadata DB; override to source requests from an API. - """ - return self._claim_priority_files() + """Fetch and claim files requested for priority parsing, via the DAG Processing API.""" + bundles = {bundle.name: bundle for bundle in self._dag_bundles} + try: + claimed = self._dag_processing_client.claim_priority_files(list(bundles)) + except Exception: + # A transient API outage must not crash the parse loop; the next cycle retries. + self.log.exception("Error claiming priority parse requests via the DAG Processing API") + return [] + return [ + DagFileInfo( + rel_path=Path(entry["relative_fileloc"]), + bundle_name=entry["bundle_name"], + bundle_path=bundles[entry["bundle_name"]].path, + ) + for entry in claimed + if entry["bundle_name"] in bundles + ] def request_bundle_refresh(self, bundle_names: str | Iterable[str]) -> None: """ @@ -587,66 +632,17 @@ def should_skip_refresh( and bundle.name not in self._force_refresh_bundles ) - @provide_session - def _claim_priority_files(self, *, session: Session = NEW_SESSION) -> list[DagFileInfo]: - """Fetch priority parsing requests from the metadata database.""" - files: list[DagFileInfo] = [] - bundles = {b.name: b for b in self._dag_bundles} - requests = session.scalars( - select(DagPriorityParsingRequest).where(DagPriorityParsingRequest.bundle_name.in_(bundles.keys())) - ) - for request in requests: - bundle = bundles[request.bundle_name] - files.append( - DagFileInfo( - rel_path=Path(request.relative_fileloc), bundle_name=bundle.name, bundle_path=bundle.path - ) - ) - session.delete(request) - return files - def fetch_callbacks(self) -> list[CallbackRequest]: - """ - Fetch and claim callbacks for this manager's bundles. - - Default implementation reads from the metadata DB; override to source callbacks from an API. - """ - return self._fetch_callbacks_from_db() - - @provide_session - @retry_db_transaction - def _fetch_callbacks_from_db( - self, - *, - session: Session = NEW_SESSION, - ) -> list[CallbackRequest]: - """Fetch callbacks from database and add them to the internal queue for execution.""" - self.log.debug("Fetching callbacks from the database.") - - callback_queue: list[CallbackRequest] = [] - with prohibit_commit(session) as guard: - bundle_names = [bundle.name for bundle in self._dag_bundles] - query: Select[tuple[DbCallbackRequest]] = with_row_locks( - select(DbCallbackRequest) - .where(DbCallbackRequest.bundle_name.in_(bundle_names)) - .order_by(DbCallbackRequest.priority_weight.desc()) - .limit(self.max_callbacks_per_loop), - of=DbCallbackRequest, - session=session, - skip_locked=True, + """Fetch and claim callbacks for this manager's bundles, via the DAG Processing API.""" + try: + return self._dag_processing_client.fetch_callbacks( + bundle_names=[bundle.name for bundle in self._dag_bundles], + limit=self.max_callbacks_per_loop, ) - callbacks: Sequence[DbCallbackRequest] = [ - cb[0] if isinstance(cb, tuple) else cb for cb in session.scalars(query) - ] - for callback in callbacks: - req = callback.get_callback_request() - try: - callback_queue.append(req) - session.delete(callback) - except Exception as e: - self.log.warning("Error adding callback for execution: %s, %s", callback, e) - guard.commit() - return callback_queue + except Exception: + # A transient API outage must not crash the parse loop; the next cycle retries. + self.log.exception("Error fetching callbacks via the DAG Processing API") + return [] def prepare_callback_bundle(self, request: CallbackRequest) -> BaseDagBundle | None: """ @@ -689,51 +685,35 @@ def _add_callback_to_queue(self, request: CallbackRequest) -> None: self._add_files_to_queue([file_info], mode="front") stats.incr("dag_processing.other_callback_count") - @provide_session - def get_bundle_state(self, bundle_name: str, *, session: Session = NEW_SESSION) -> BundleState | None: + def get_bundle_state(self, bundle_name: str) -> BundleState | None: """ - Return the persisted refresh state for a bundle. + Return the persisted refresh state for a bundle, via the DAG Processing API. - Returns ``None`` if the bundle has no database record. + Returns ``None`` if the bundle has no record. """ - row = session.scalar( - select(DagBundleModel) - .where(DagBundleModel.name == bundle_name) - .options(load_only(DagBundleModel.last_refreshed, DagBundleModel.version)) - ) - if row is None: + data = self._dag_processing_client.get_bundle_state(bundle_name) + if data is None: return None - return BundleState(last_refreshed=row.last_refreshed, version=row.version) + return BundleState(last_refreshed=data["last_refreshed"], version=data["version"]) - @provide_session - def update_bundle_state( - self, - bundle_name: str, - *, - last_refreshed: datetime, - version: str | None, - session: Session = NEW_SESSION, - ) -> None: + def update_bundle_state(self, bundle_name: str, *, last_refreshed: datetime, version: str | None) -> None: """ - Persist the post-refresh state for a bundle. + Persist the post-refresh state for a bundle, via the DAG Processing API. - Always updates ``last_refreshed``. Updates ``version`` only when ``version`` is not - ``None`` — pass ``None`` to leave the stored version unchanged (e.g. for non-versioned - bundles or when the version did not change after a refresh). + Always updates ``last_refreshed``; updates ``version`` only when ``version`` is not + ``None`` (pass ``None`` to leave the stored version unchanged). """ - values: dict[str, Any] = {"last_refreshed": last_refreshed} - if version is not None: - values["version"] = version - session.execute(update(DagBundleModel).where(DagBundleModel.name == bundle_name).values(**values)) + self._dag_processing_client.update_bundle_state( + bundle_name, last_refreshed=last_refreshed, version=version + ) def purge_inactive_dag_warnings(self) -> None: - """ - Purge warnings for inactive/stale DAGs. - - Default implementation deletes records from the metadata DB; override to - source warnings from an API or skip the cleanup entirely. - """ - DagWarning.purge_inactive_dag_warnings() + """Purge warnings for inactive/stale DAGs, via the DAG Processing API.""" + try: + self._dag_processing_client.purge_inactive_dag_warnings() + except Exception: + # A transient API outage must not crash the parse loop; the next cycle retries. + self.log.exception("Error purging inactive DAG warnings via the DAG Processing API") def _refresh_dag_bundles(self, known_files: dict[str, set[DagFileInfo]]): """Refresh DAG bundles, if required.""" @@ -846,11 +826,17 @@ def _refresh_dag_bundles(self, known_files: dict[str, set[DagFileInfo]]): known_files[bundle.name] = found_files - self.deactivate_deleted_dags(bundle_name=bundle.name, present=found_files) - self.clear_orphaned_import_errors( - bundle_name=bundle.name, - observed_filelocs=self._get_observed_filelocs(found_files), - ) + try: + self._dag_processing_client.reconcile( + bundle_name=bundle.name, + observed_filelocs=self._get_observed_filelocs(found_files), + ) + except Exception: + self.log.exception( + "Error reconciling bundle %s via the DAG Processing API; " + "skipping stale reconciliation this cycle", + bundle.name, + ) if any_refreshed: self.handle_removed_files(known_files=known_files) @@ -899,20 +885,6 @@ def find_zipped_dags(abs_path: os.PathLike) -> Iterator[str]: return observed_filelocs - def deactivate_deleted_dags(self, bundle_name: str, present: set[DagFileInfo]) -> None: - """Deactivate DAGs that come from files that are no longer present in bundle.""" - observed_filelocs = self._get_observed_filelocs(present) - with create_session() as session: - any_deactivated = DagModel.deactivate_deleted_dags( - bundle_name=bundle_name, - rel_filelocs=observed_filelocs, - session=session, - ) - # Only run cleanup if we actually deactivated any DAGs - # This avoids unnecessary DELETE queries in the common case where no DAGs were deleted - if any_deactivated: - remove_references_to_deleted_dags(session=session) - def print_stats(self, known_files: dict[str, set[DagFileInfo]]): """Occasionally print out stats about how fast the files are getting processed.""" if 0 < self.print_stats_interval < time.monotonic() - self.last_stat_print_time: @@ -920,28 +892,6 @@ def print_stats(self, known_files: dict[str, set[DagFileInfo]]): self._log_file_processing_stats(known_files=known_files) self.last_stat_print_time = time.monotonic() - @provide_session - def clear_orphaned_import_errors( - self, bundle_name: str, observed_filelocs: set[str], *, session: Session = NEW_SESSION - ): - """ - Clear import errors for files that no longer exist. - - :param session: session for ORM operations - """ - self.log.debug("Removing old import errors") - try: - errors = session.scalars( - select(ParseImportError) - .where(ParseImportError.bundle_name == bundle_name) - .options(load_only(ParseImportError.filename)) - ) - for error in errors: - if error.filename not in observed_filelocs: - session.delete(error) - except Exception: - self.log.exception("Error removing old import errors") - def _log_file_processing_stats(self, known_files: dict[str, set[DagFileInfo]]): """ Print out stats about how files are getting processed. @@ -1091,13 +1041,10 @@ def terminate_orphan_processes(self, present: set[DagFileInfo]): processor.logger_filehandle.close() self._file_stats.pop(file, None) - @provide_session def handle_parsing_result( self, file: DagFileInfo, proc: DagFileProcessorProcess, - *, - session: Session = NEW_SESSION, ) -> None: """ Post-process a single finished parse result. @@ -1107,11 +1054,6 @@ def handle_parsing_result( Extracted from ``_collect_results`` to keep result handling and persistence separate. - Owns its own DB session via ``@provide_session`` so subclasses that - forward results without touching the metadata DB (e.g. AIP-92 API-backed - deployments) can override this method without inheriting a session - created by the caller. - If persistence fails, the error is logged and the previous persisted DAG/import-error counts are preserved while a minimal timestamp update throttles immediate retries, so other files in the same @@ -1142,7 +1084,6 @@ def handle_parsing_result( parsing_result=proc.parsing_result, run_duration=run_duration, relative_fileloc=str(file.rel_path), - session=session, ) except Exception: self.log.exception( @@ -1174,36 +1115,15 @@ def persist_parsing_result( parsing_result: DagFileParsingResult, run_duration: float, relative_fileloc: str | None, - session: Session, ) -> None: - """Persist parsed DAG data to the metadata database.""" - import_errors: dict[tuple[str, str], str] = {} - if parsing_result.import_errors: - import_errors = { - (bundle_name, rel_path): error for rel_path, error in parsing_result.import_errors.items() - } - - # Build the set of files that were parsed. This includes the file that was parsed, - # even if it no longer contains DAGs, so we can clear old import errors. - files_parsed: set[tuple[str, str]] | None = None - if relative_fileloc is not None: - files_parsed = {(bundle_name, relative_fileloc)} - files_parsed.update(import_errors.keys()) - - warnings = parsing_result.warnings or [] - if warnings and isinstance(warnings[0], dict): - warnings = [DagWarning(**warn) for warn in warnings] - - update_dag_parsing_results_in_db( + """Persist parsed DAG data via the DAG Processing API.""" + self._dag_processing_client.persist_parsing_result( bundle_name=bundle_name, bundle_version=bundle_version, version_data=version_data, - dags=parsing_result.serialized_dags, - import_errors=import_errors, - parse_duration=run_duration, - warnings=set(warnings), - session=session, - files_parsed=files_parsed, + parsing_result=parsing_result, + run_duration=run_duration, + relative_fileloc=relative_fileloc, ) def _collect_results(self): @@ -1267,14 +1187,18 @@ def _get_logger_for_dag_file(self, dag_file: DagFileInfo): underlying_logger, processors=processors, logger_name="processor" ).bind(), logger_filehandle - @functools.cached_property + @property def client(self) -> Client: - from airflow.sdk.api.client import Client - - client = Client(base_url=None, token="", dry_run=True, transport=self._api_server.transport) - # Mypy is wrong -- the setter accepts a string on the property setter! `URLType = URL | str` - client.base_url = "http://in-process.invalid./" - return client + # Parse-time connection/variable/xcom reads go to the remote Execution API, so the + # processor holds no metadata-DB connection. It carries the externally-provisioned token + # (see _api_token); it does not mint one, since it parses user code. Not cached: the token + # is re-read here so each parser process spawned later carries a token rotated on disk by + # the deployment, rather than a stale one baked in at manager start-up. + return Client( + base_url=get_execution_api_server_url(), + token=_api_token() or "", + dry_run=False, + ) def _create_process(self, dag_file: DagFileInfo) -> DagFileProcessorProcess: id = uuid7() diff --git a/airflow-core/tests/unit/api_fastapi/auth/test_dag_processor_token.py b/airflow-core/tests/unit/api_fastapi/auth/test_dag_processor_token.py new file mode 100644 index 0000000000000..67400109d51b7 --- /dev/null +++ b/airflow-core/tests/unit/api_fastapi/auth/test_dag_processor_token.py @@ -0,0 +1,85 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +import pytest + +from airflow.api_fastapi.auth import dag_processor_token + +from tests_common.test_utils.config import conf_vars + + +class TestMintDagProcessorToken: + @mock.patch.object(dag_processor_token, "get_signing_args", return_value={"secret_key": "s"}) + @mock.patch.object(dag_processor_token, "JWTGenerator") + def test_carries_both_audiences_and_sentinel_subject(self, mock_generator, _signing_args): + mock_generator.return_value.generate.return_value = "minted-token" + with conf_vars( + { + ("execution_api", "jwt_audience"): "exec-aud", + ("dag_processor", "jwt_audience"): "dag-proc-aud", + ("dag_processor", "jwt_expiration_time"): "123", + } + ): + token = dag_processor_token.mint_dag_processor_token() + + assert token == "minted-token" + # Both audiences, since the processor presents this one token to both APIs. + assert mock_generator.call_args.kwargs["audience"] == ["exec-aud", "dag-proc-aud"] + assert mock_generator.call_args.kwargs["valid_for"] == 123 + mock_generator.return_value.generate.assert_called_once_with( + {"sub": dag_processor_token.DAG_PROCESSOR_TOKEN_SUBJECT} + ) + + @mock.patch.object(dag_processor_token, "get_signing_args", return_value={"secret_key": "s"}) + @mock.patch.object(dag_processor_token, "JWTGenerator") + def test_explicit_valid_for_overrides_config(self, mock_generator, _signing_args): + with conf_vars( + { + ("execution_api", "jwt_audience"): "exec-aud", + ("dag_processor", "jwt_audience"): "dag-proc-aud", + } + ): + dag_processor_token.mint_dag_processor_token(valid_for=42) + assert mock_generator.call_args.kwargs["valid_for"] == 42 + + +class TestProvisionDagProcessorTokenFile: + @mock.patch.object(dag_processor_token, "mint_dag_processor_token", return_value="minted") + def test_writes_to_explicit_path_and_creates_parents(self, _mint, tmp_path): + target = tmp_path / "nested" / "dag-processor-token" + + returned = dag_processor_token.provision_dag_processor_token_file(target) + + assert returned == str(target) + assert target.read_text() == "minted" + # Atomic write leaves no temp file behind. + assert not (target.parent / "dag-processor-token.tmp").exists() + + @mock.patch.object(dag_processor_token, "mint_dag_processor_token", return_value="minted") + def test_defaults_to_configured_api_token_path(self, _mint, tmp_path): + target = tmp_path / "token" + with conf_vars({("dag_processor", "api_token_path"): str(target)}): + dag_processor_token.provision_dag_processor_token_file() + assert target.read_text() == "minted" + + def test_raises_when_no_path_configured(self): + with conf_vars({("dag_processor", "api_token_path"): ""}): + with pytest.raises(ValueError, match="api_token_path"): + dag_processor_token.provision_dag_processor_token_file(None) diff --git a/airflow-core/tests/unit/api_fastapi/dag_processing/__init__.py b/airflow-core/tests/unit/api_fastapi/dag_processing/__init__.py new file mode 100644 index 0000000000000..21d298ede6ed3 --- /dev/null +++ b/airflow-core/tests/unit/api_fastapi/dag_processing/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations diff --git a/airflow-core/tests/unit/api_fastapi/dag_processing/test_app.py b/airflow-core/tests/unit/api_fastapi/dag_processing/test_app.py new file mode 100644 index 0000000000000..f1a44427c8ab0 --- /dev/null +++ b/airflow-core/tests/unit/api_fastapi/dag_processing/test_app.py @@ -0,0 +1,383 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +import pytest +from fastapi.testclient import TestClient + +from airflow.api_fastapi.dag_processing.app import create_dag_processing_api_app +from airflow.api_fastapi.dag_processing.security import require_dag_processing_auth + +from tests_common.test_utils.config import conf_vars + +APP = "airflow.api_fastapi.dag_processing.app" + + +@pytest.fixture +def client(): + # Endpoint-logic tests bypass the token check; auth itself is covered separately below. + app = create_dag_processing_api_app() + app.dependency_overrides[require_dag_processing_auth] = lambda: None + return TestClient(app) + + +def test_health(client): + resp = client.get("/health") + assert resp.status_code == 200 + assert resp.json() == {"status": "healthy"} + + +def test_persistence_endpoint_requires_a_token(): + """A persistence endpoint rejects an unauthenticated call (no token).""" + unauth = TestClient(create_dag_processing_api_app(), raise_server_exceptions=False) + resp = unauth.post("/bundles/sync") + assert resp.status_code == 401 + + +def test_persistence_endpoint_rejects_invalid_token(): + """A persistence endpoint rejects a malformed/garbage token.""" + unauth = TestClient(create_dag_processing_api_app(), raise_server_exceptions=False) + resp = unauth.post("/bundles/sync", headers={"Authorization": "Bearer not-a-real-token"}) + assert resp.status_code == 403 + + +def test_health_does_not_require_a_token(): + """``/health`` stays open so external readiness probes work.""" + unauth = TestClient(create_dag_processing_api_app()) + assert unauth.get("/health").status_code == 200 + + +def test_persistence_endpoint_accepts_a_valid_token(): + """A token signed with the deployment key for the DAG-processing audience is accepted.""" + from airflow.api_fastapi.auth.tokens import JWTGenerator, get_signing_args + from airflow.api_fastapi.dag_processing.security import _token_validator + from airflow.configuration import conf + + with conf_vars({("api_auth", "jwt_secret"): "dag-processing-test-secret"}): + _token_validator.cache_clear() + # A trusted component (not the DAG processor) mints the token the processor presents. + # Mirror the audience/issuer the validator reads so the token is accepted. + token = JWTGenerator( + valid_for=300, + audience=conf.get_mandatory_list_value("dag_processor", "jwt_audience")[0], + issuer=conf.get("api_auth", "jwt_issuer", fallback=None), + **get_signing_args(make_secret_key_if_needed=False), + ).generate({"sub": "dag-processor"}) + authed = TestClient(create_dag_processing_api_app(), raise_server_exceptions=False) + with mock.patch(f"{APP}.DagBundlesManager"): + resp = authed.post("/bundles/sync", headers={"Authorization": f"Bearer {token}"}) + _token_validator.cache_clear() + assert resp.status_code == 200 + + +def test_persist_parsing_results_forwards_to_update_db(client): + with ( + mock.patch(f"{APP}.update_dag_parsing_results_in_db") as update_db, + mock.patch(f"{APP}.LazyDeserializedDAG") as lazy_dag, + mock.patch(f"{APP}.create_session"), + ): + lazy_dag.model_validate.side_effect = lambda d: d + body = { + "bundle_name": "b1", + "bundle_version": "v1", + "relative_fileloc": "dags/a.py", + "run_duration": 1.5, + "serialized_dags": [{"data": {"dag": {"dag_id": "x"}}}], + "import_errors": {"dags/a.py": "boom"}, + "warnings": [], + } + resp = client.post("/parsing-results", json=body) + + assert resp.status_code == 201 + update_db.assert_called_once() + kwargs = update_db.call_args.kwargs + assert kwargs["bundle_name"] == "b1" + assert kwargs["bundle_version"] == "v1" + # import errors are re-keyed to (bundle_name, relative_fileloc) tuples + assert kwargs["import_errors"] == {("b1", "dags/a.py"): "boom"} + # files_parsed includes the parsed file plus any import-error files + assert ("b1", "dags/a.py") in kwargs["files_parsed"] + + +def test_persist_parsing_results_without_fileloc_has_no_files_parsed(client): + with ( + mock.patch(f"{APP}.update_dag_parsing_results_in_db") as update_db, + mock.patch(f"{APP}.LazyDeserializedDAG"), + mock.patch(f"{APP}.create_session"), + ): + resp = client.post("/parsing-results", json={"bundle_name": "b1", "serialized_dags": []}) + + assert resp.status_code == 201 + assert update_db.call_args.kwargs["files_parsed"] is None + + +def test_persist_parsing_results_invalid_payload_returns_422(client): + with ( + mock.patch(f"{APP}.LazyDeserializedDAG") as lazy_dag, + mock.patch(f"{APP}.update_dag_parsing_results_in_db") as update_db, + ): + lazy_dag.model_validate.side_effect = ValueError("bad blob") + resp = client.post("/parsing-results", json={"bundle_name": "b1", "serialized_dags": [{"data": {}}]}) + + assert resp.status_code == 422 + update_db.assert_not_called() + + +def test_reconcile_deactivates_and_removes_references(client): + with ( + mock.patch(f"{APP}.DagModel") as dag_model, + mock.patch(f"{APP}.remove_references_to_deleted_dags") as remove_refs, + mock.patch(f"{APP}.create_session") as create_session, + ): + dag_model.deactivate_deleted_dags.return_value = True + create_session.return_value.__enter__.return_value.scalars.return_value = [] + resp = client.post("/bundles/b1/reconcile", json={"observed_filelocs": ["dags/a.py"]}) + + assert resp.status_code == 200 + assert resp.json() == {"bundle_name": "b1", "deactivated": True} + dag_model.deactivate_deleted_dags.assert_called_once() + remove_refs.assert_called_once() + + +def test_reconcile_skips_remove_references_when_nothing_deactivated(client): + with ( + mock.patch(f"{APP}.DagModel") as dag_model, + mock.patch(f"{APP}.remove_references_to_deleted_dags") as remove_refs, + mock.patch(f"{APP}.create_session") as create_session, + ): + dag_model.deactivate_deleted_dags.return_value = False + create_session.return_value.__enter__.return_value.scalars.return_value = [] + resp = client.post("/bundles/b1/reconcile", json={"observed_filelocs": []}) + + assert resp.status_code == 200 + assert resp.json()["deactivated"] is False + remove_refs.assert_not_called() + + +def test_reconcile_import_error_cleanup_failure_is_swallowed(client): + """A cleanup failure must not fail the request nor (being a separate txn) undo deactivation.""" + with ( + mock.patch(f"{APP}.DagModel") as dag_model, + mock.patch(f"{APP}.remove_references_to_deleted_dags"), + mock.patch(f"{APP}.create_session") as create_session, + ): + dag_model.deactivate_deleted_dags.return_value = True + # The import-error cleanup block calls session.scalars(); make it blow up. + create_session.return_value.__enter__.return_value.scalars.side_effect = RuntimeError("db down") + resp = client.post("/bundles/b1/reconcile", json={"observed_filelocs": ["dags/a.py"]}) + + assert resp.status_code == 200 + # Deactivation still ran (in its own, earlier transaction). + dag_model.deactivate_deleted_dags.assert_called_once() + + +def test_get_bundle_state_found(client): + from datetime import datetime, timezone + + row = mock.MagicMock(last_refreshed=datetime(2024, 1, 1, tzinfo=timezone.utc), version="v1") + with ( + mock.patch(f"{APP}.create_session") as create_session, + ): + create_session.return_value.__enter__.return_value.scalar.return_value = row + resp = client.get("/bundles/b1/state") + + assert resp.status_code == 200 + body = resp.json() + assert body["found"] is True + assert body["version"] == "v1" + + +def test_get_bundle_state_not_found(client): + with ( + mock.patch(f"{APP}.create_session") as create_session, + ): + create_session.return_value.__enter__.return_value.scalar.return_value = None + resp = client.get("/bundles/missing/state") + + assert resp.status_code == 200 + assert resp.json() == {"found": False, "last_refreshed": None, "version": None} + + +def test_update_bundle_state_includes_version_when_provided(client): + with ( + mock.patch(f"{APP}.update") as update_stmt, + mock.patch(f"{APP}.create_session") as create_session, + ): + resp = client.patch( + "/bundles/b1/state", + json={"last_refreshed": "2024-06-01T08:00:00+00:00", "version": "abc"}, + ) + + assert resp.status_code == 200 + # version supplied -> both columns written + update_stmt.return_value.where.return_value.values.assert_called_once_with( + last_refreshed=mock.ANY, version="abc" + ) + create_session.return_value.__enter__.return_value.execute.assert_called_once() + + +def test_update_bundle_state_omits_version_when_null(client): + with ( + mock.patch(f"{APP}.update") as update_stmt, + mock.patch(f"{APP}.create_session"), + ): + resp = client.patch( + "/bundles/b1/state", + json={"last_refreshed": "2024-06-01T08:00:00+00:00", "version": None}, + ) + + assert resp.status_code == 200 + # version omitted -> only last_refreshed written (stored version left unchanged) + update_stmt.return_value.where.return_value.values.assert_called_once_with(last_refreshed=mock.ANY) + + +def test_sync_bundles_triggers_server_side_sync(client): + with mock.patch(f"{APP}.DagBundlesManager") as dbm: + resp = client.post("/bundles/sync") + + assert resp.status_code == 200 + dbm.return_value.sync_bundles_to_db.assert_called_once_with() + + +def test_deactivate_stale_dags_inactive_bundle_and_threshold(client): + from datetime import datetime, timedelta, timezone + + now = datetime(2024, 6, 1, 12, 0, 0, tzinfo=timezone.utc) + r_inactive = mock.MagicMock( + dag_id="d_inactive", bundle_name="dead", last_parsed_time=now, relative_fileloc="a.py" + ) + r_stale = mock.MagicMock( + dag_id="d_stale", + bundle_name="live", + last_parsed_time=now - timedelta(hours=2), + relative_fileloc="b.py", + ) + r_fresh = mock.MagicMock( + dag_id="d_fresh", bundle_name="live", last_parsed_time=now, relative_fileloc="c.py" + ) + with mock.patch(f"{APP}.create_session") as create_session: + session = create_session.return_value.__enter__.return_value + session.scalars.return_value.all.return_value = ["dead"] + session.execute.side_effect = [[r_inactive, r_stale, r_fresh], mock.MagicMock(rowcount=2)] + body = { + "stale_dag_threshold": 60, + "last_parsed": [ + {"bundle_name": "live", "relative_fileloc": "b.py", "last_finish_time": now.isoformat()}, + {"bundle_name": "live", "relative_fileloc": "c.py", "last_finish_time": now.isoformat()}, + ], + } + resp = client.post("/stale-dags", json=body) + + assert resp.status_code == 200 + # d_inactive (dead bundle) + d_stale (past threshold); d_fresh stays. + assert resp.json()["deactivated"] == 2 + + +def test_purge_warnings_endpoint(client): + with mock.patch(f"{APP}.DagWarning") as dag_warning: + resp = client.post("/purge-warnings") + + assert resp.status_code == 200 + dag_warning.purge_inactive_dag_warnings.assert_called_once_with() + + +def test_claim_priority_parse_requests(client): + req = mock.MagicMock(bundle_name="b1", relative_fileloc="dags/a.py") + with mock.patch(f"{APP}.create_session") as create_session: + session = create_session.return_value.__enter__.return_value + session.scalars.return_value = [req] + resp = client.post("/priority-parse-requests/claim", json={"bundle_names": ["b1"]}) + + assert resp.status_code == 200 + assert resp.json()["claimed"] == [{"bundle_name": "b1", "relative_fileloc": "dags/a.py"}] + session.delete.assert_called_once_with(req) + + +def test_claim_callbacks_skip_locked_and_delete(client): + cb = mock.MagicMock(data={"req_class": "DagCallbackRequest", "req_data": "{}"}) + with ( + mock.patch(f"{APP}.with_row_locks"), + mock.patch(f"{APP}.prohibit_commit") as prohibit, + mock.patch(f"{APP}.create_session") as create_session, + ): + session = create_session.return_value.__enter__.return_value + session.scalars.return_value = [cb] + resp = client.post("/callbacks/claim", json={"bundle_names": ["b1"], "limit": 5}) + + assert resp.status_code == 200 + assert resp.json()["callbacks"] == [{"req_class": "DagCallbackRequest", "req_data": "{}"}] + session.delete.assert_called_once_with(cb) + # the claim is committed via the prohibit_commit guard + prohibit.return_value.__enter__.return_value.commit.assert_called_once() + + +def test_register_job(client): + job = mock.MagicMock(id=42) + with mock.patch(f"{APP}.Job", return_value=job), mock.patch(f"{APP}.create_session"): + resp = client.post( + "/jobs", + json={ + "job_type": "DagProcessorJob", + "hostname": "dag-proc-host", + "unixname": "airflow", + "pid": 1234, + }, + ) + + assert resp.status_code == 201 + assert resp.json() == {"job_id": 42} + job.prepare_for_execution.assert_called_once() + # The processor's reported identity is recorded (not this server's), so its health check matches. + assert job.hostname == "dag-proc-host" + assert job.pid == 1234 + + +def test_register_job_requires_hostname(client): + """hostname is required so the Job row reflects the processor, not the API server.""" + resp = client.post("/jobs", json={"job_type": "DagProcessorJob"}) + assert resp.status_code == 422 + + +def test_job_heartbeat_updates_latest_heartbeat(client): + job = mock.MagicMock() + with mock.patch(f"{APP}.create_session") as create_session: + create_session.return_value.__enter__.return_value.get.return_value = job + resp = client.post("/jobs/7/heartbeat") + + assert resp.status_code == 200 + assert resp.json() == {"alive": True} + assert job.latest_heartbeat is not None + + +def test_job_heartbeat_missing_returns_404(client): + with mock.patch(f"{APP}.create_session") as create_session: + create_session.return_value.__enter__.return_value.get.return_value = None + resp = client.post("/jobs/7/heartbeat") + + assert resp.status_code == 404 + + +def test_complete_job_sets_state(client): + job = mock.MagicMock() + with mock.patch(f"{APP}.create_session") as create_session: + create_session.return_value.__enter__.return_value.get.return_value = job + resp = client.post("/jobs/7/complete", json={"state": "success"}) + + assert resp.status_code == 200 + assert job.state == "success" diff --git a/airflow-core/tests/unit/api_fastapi/dag_processing/test_app_db.py b/airflow-core/tests/unit/api_fastapi/dag_processing/test_app_db.py new file mode 100644 index 0000000000000..1ef15ca3d0263 --- /dev/null +++ b/airflow-core/tests/unit/api_fastapi/dag_processing/test_app_db.py @@ -0,0 +1,291 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Real-database tests for the DAG Processing API persistence endpoints. + +These complement ``test_app.py`` (which mocks ``create_session``/``with_row_locks``) +by running the endpoints against a real session, so the actual SQL -- ordering, +limit, WHERE clauses, lock, and delete -- is exercised end to end. +""" + +from __future__ import annotations + +import json +from datetime import timedelta + +import pytest +from fastapi.testclient import TestClient +from sqlalchemy import select, update + +from airflow._shared.timezones import timezone +from airflow.api_fastapi.dag_processing.app import create_dag_processing_api_app +from airflow.api_fastapi.dag_processing.security import require_dag_processing_auth +from airflow.callbacks.callback_requests import DagCallbackRequest +from airflow.models.dag import DagModel +from airflow.models.dagbag import DagPriorityParsingRequest +from airflow.models.dagbundle import DagBundleModel +from airflow.models.db_callback_request import DbCallbackRequest +from airflow.models.errors import ParseImportError +from airflow.models.serialized_dag import SerializedDagModel + +from tests_common.test_utils.db import ( + clear_db_callbacks, + clear_db_dag_bundles, + clear_db_dag_parsing_requests, + clear_db_dags, + clear_db_import_errors, + clear_db_serialized_dags, +) + +pytestmark = pytest.mark.db_test + + +@pytest.fixture +def client(): + # The token check is covered in test_app.py; here we bypass it and exercise the real DB. + app = create_dag_processing_api_app() + app.dependency_overrides[require_dag_processing_auth] = lambda: None + return TestClient(app) + + +def _make_callback(*, bundle_name: str, run_id: str) -> DagCallbackRequest: + return DagCallbackRequest( + dag_id="test_dag", + bundle_name=bundle_name, + bundle_version=None, + filepath="some_dag.py", + is_failure_callback=True, + run_id=run_id, + ) + + +class TestClaimCallbacks: + def setup_method(self): + clear_db_callbacks() + + def teardown_method(self): + clear_db_callbacks() + + def test_claim_orders_by_priority_honors_limit_filters_bundle_and_deletes(self, client, session): + """POST /callbacks/claim returns the owned bundle's callbacks ordered by priority_weight + DESC up to ``limit``, deletes the claimed rows, and leaves the other bundle's row and the + rows beyond the limit untouched.""" + # Arrange: three callbacks for the owned bundle (varying priority) + one for another bundle. + session.add( + DbCallbackRequest(callback=_make_callback(bundle_name="owned", run_id="low"), priority_weight=1) + ) + session.add( + DbCallbackRequest(callback=_make_callback(bundle_name="owned", run_id="high"), priority_weight=10) + ) + session.add( + DbCallbackRequest(callback=_make_callback(bundle_name="owned", run_id="mid"), priority_weight=5) + ) + session.add( + DbCallbackRequest( + callback=_make_callback(bundle_name="other", run_id="other"), priority_weight=100 + ) + ) + session.commit() + + # Act: claim from the owned bundle with a limit of 2. + resp = client.post("/callbacks/claim", json={"bundle_names": ["owned"], "limit": 2}) + + # Assert: highest-priority owned callbacks first, count capped at the limit. + assert resp.status_code == 200 + claimed = resp.json()["callbacks"] + assert len(claimed) == 2 + run_ids = [json.loads(cb["req_data"])["run_id"] for cb in claimed] + assert run_ids == ["high", "mid"] + + # The claimed rows are deleted; the lower-priority owned row and the other bundle's row remain. + remaining = session.scalars(select(DbCallbackRequest)).all() + remaining_run_ids = {json.loads(cb.data["req_data"])["run_id"] for cb in remaining} + assert remaining_run_ids == {"low", "other"} + + +class TestDeactivateStaleDags: + def setup_method(self): + clear_db_serialized_dags() + clear_db_dags() + clear_db_dag_bundles() + + def teardown_method(self): + clear_db_serialized_dags() + clear_db_dags() + clear_db_dag_bundles() + + def test_marks_inactive_bundle_and_past_threshold_dags_keeps_fresh_and_serialized( + self, client, dag_maker, session + ): + """POST /stale-dags marks the inactive-bundle DAG and the past-threshold DAG stale, leaves + the freshly parsed DAG active, and preserves the SerializedDagModel row.""" + now = timezone.utcnow() + threshold = 60 + + # A freshly parsed DAG with a real SerializedDagModel, in an active bundle. + with dag_maker(dag_id="fresh_dag", bundle_name="live", serialized=True, session=session): + pass + dag_maker.dag_model.last_parsed_time = now + dag_maker.dag_model.relative_fileloc = "fresh.py" + dag_maker.dag_model.is_stale = False + session.merge(dag_maker.dag_model) + + # An inactive bundle plus a DAG that lives in it. + session.add(DagBundleModel(name="gone-bundle")) + session.flush() + session.execute( + update(DagBundleModel).where(DagBundleModel.name == "gone-bundle").values(active=False) + ) + session.add( + DagModel( + dag_id="dag_in_inactive_bundle", + bundle_name="gone-bundle", + relative_fileloc="inactive.py", + last_parsed_time=now, + is_stale=False, + ) + ) + # A DAG in the active bundle whose last parse predates last_finish_time by more than the threshold. + session.add( + DagModel( + dag_id="dag_past_threshold", + bundle_name="live", + relative_fileloc="stale.py", + last_parsed_time=now - timedelta(hours=2), + is_stale=False, + ) + ) + session.commit() + + # Act: last_finish_time for the live-bundle files is "now"; fresh.py's gap is under the + # threshold, stale.py's gap (2h) is over it. + body = { + "stale_dag_threshold": threshold, + "last_parsed": [ + {"bundle_name": "live", "relative_fileloc": "fresh.py", "last_finish_time": now.isoformat()}, + {"bundle_name": "live", "relative_fileloc": "stale.py", "last_finish_time": now.isoformat()}, + ], + } + resp = client.post("/stale-dags", json=body) + + # Assert: two DAGs deactivated; the fresh DAG stays active. + assert resp.status_code == 200 + assert resp.json()["deactivated"] == 2 + + is_stale_by_dag = dict( + session.execute( + select(DagModel.dag_id, DagModel.is_stale).where( + DagModel.dag_id.in_(["fresh_dag", "dag_in_inactive_bundle", "dag_past_threshold"]) + ) + ).all() + ) + assert is_stale_by_dag == { + "fresh_dag": False, + "dag_in_inactive_bundle": True, + "dag_past_threshold": True, + } + + # Deactivating DagModel rows must not delete the SerializedDagModel history. + serialized_count = session.scalar( + select(SerializedDagModel.dag_id).where(SerializedDagModel.dag_id == "fresh_dag") + ) + assert serialized_count == "fresh_dag" + + +class TestReconcileImportErrors: + def setup_method(self): + clear_db_import_errors() + clear_db_serialized_dags() + clear_db_dags() + + def teardown_method(self): + clear_db_import_errors() + clear_db_serialized_dags() + clear_db_dags() + + def test_deletes_absent_file_error_but_keeps_observed_and_zip_inner_errors(self, client, session): + """POST /bundles/{name}/reconcile deletes the import error for a file no longer observed, + while retaining the error for a still-observed file and for a zip inner path that is + observed (e.g. ``test_zip.zip/broken_dag.py``).""" + bundle = "testing" + # One error for a normal file that is still present, one zip-inner error that is still + # observed, and one for a file that is now absent from the bundle. + session.add( + ParseImportError( + filename="present_dag.py", + bundle_name=bundle, + timestamp=timezone.utcnow(), + stacktrace="present error", + ) + ) + session.add( + ParseImportError( + filename="test_zip.zip/broken_dag.py", + bundle_name=bundle, + timestamp=timezone.utcnow(), + stacktrace="zip import error", + ) + ) + session.add( + ParseImportError( + filename="absent_dag.py", + bundle_name=bundle, + timestamp=timezone.utcnow(), + stacktrace="absent error", + ) + ) + session.commit() + + # Act: observed set includes the normal file and the zip inner path, but not absent_dag.py. + resp = client.post( + f"/bundles/{bundle}/reconcile", + json={"observed_filelocs": ["present_dag.py", "test_zip.zip/broken_dag.py"]}, + ) + + # Assert: only the absent file's error is removed. + assert resp.status_code == 200 + remaining = { + err.filename + for err in session.scalars(select(ParseImportError).where(ParseImportError.bundle_name == bundle)) + } + assert remaining == {"present_dag.py", "test_zip.zip/broken_dag.py"} + + +class TestClaimPriorityParseRequests: + def setup_method(self): + clear_db_dag_parsing_requests() + + def teardown_method(self): + clear_db_dag_parsing_requests() + + def test_claims_owned_bundle_request_and_leaves_other_bundle_request(self, client, session): + """POST /priority-parse-requests/claim returns and deletes the owned bundle's request while + leaving a request that belongs to a bundle this processor does not own.""" + session.add(DagPriorityParsingRequest(bundle_name="owned", relative_fileloc="owned_file.py")) + session.add(DagPriorityParsingRequest(bundle_name="other", relative_fileloc="other_file.py")) + session.commit() + + # Act: claim only the owned bundle. + resp = client.post("/priority-parse-requests/claim", json={"bundle_names": ["owned"]}) + + # Assert: the owned request is returned and deleted; the other bundle's request remains. + assert resp.status_code == 200 + assert resp.json()["claimed"] == [{"bundle_name": "owned", "relative_fileloc": "owned_file.py"}] + + remaining = session.scalars(select(DagPriorityParsingRequest)).all() + assert len(remaining) == 1 + assert remaining[0].bundle_name == "other" + assert remaining[0].relative_fileloc == "other_file.py" diff --git a/airflow-core/tests/unit/cli/commands/test_standalone_command.py b/airflow-core/tests/unit/cli/commands/test_standalone_command.py index 993f51ca2d72d..4149c5928a94e 100644 --- a/airflow-core/tests/unit/cli/commands/test_standalone_command.py +++ b/airflow-core/tests/unit/cli/commands/test_standalone_command.py @@ -20,6 +20,7 @@ import os from collections import deque from importlib import reload +from pathlib import Path from unittest import mock import pytest @@ -32,6 +33,8 @@ LOCAL_EXECUTOR, ) +from tests_common.test_utils.config import conf_vars + class TestStandaloneCommand: @pytest.mark.parametrize( @@ -82,6 +85,31 @@ class FakeExecutor: assert "AIRFLOW__CORE__AUTH_MANAGER" not in env + @mock.patch("airflow.api_fastapi.auth.dag_processor_token.JWTGenerator") + @mock.patch( + "airflow.api_fastapi.auth.dag_processor_token.get_signing_args", return_value={"secret_key": "s"} + ) + @mock.patch("airflow.cli.commands.standalone_command.get_signing_key", return_value="the-secret") + def test_provision_dag_processor_token(self, _get_key, _get_args, mock_generator): + """Standalone mints the processor's token and provisions it via env + a token file.""" + mock_generator.return_value.generate.return_value = "minted-token" + env: dict[str, str] = {} + with conf_vars( + { + ("execution_api", "jwt_audience"): "exec-aud", + ("dag_processor", "jwt_audience"): "dag-proc-aud", + } + ): + StandaloneCommand()._provision_dag_processor_token(env) + + # The processor never mints: standalone provisions the token + a shared signing key. + assert env["AIRFLOW__API_AUTH__JWT_SECRET"] == "the-secret" + token_path = env["AIRFLOW__DAG_PROCESSOR__API_TOKEN_PATH"] + assert Path(token_path).read_text() == "minted-token" + # Token is minted for both the Execution and DAG Processing audiences. + assert mock_generator.call_args.kwargs["audience"] == ["exec-aud", "dag-proc-aud"] + Path(token_path).unlink() + @mock.patch("airflow.cli.commands.standalone_command.os.path.exists", return_value=False) @mock.patch("airflow.cli.commands.standalone_command.create_auth_manager") @mock.patch("airflow.cli.commands.standalone_command.conf.get") diff --git a/airflow-core/tests/unit/dag_processing/test_api_client.py b/airflow-core/tests/unit/dag_processing/test_api_client.py new file mode 100644 index 0000000000000..c22f6c1d5b70e --- /dev/null +++ b/airflow-core/tests/unit/dag_processing/test_api_client.py @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from datetime import datetime, timezone + +import httpx +import pytest + +from airflow.dag_processing.api_client import _CallableBearerAuth, _parse_iso_datetime + + +class TestParseIsoDatetime: + @pytest.mark.parametrize( + "value", + [ + "2026-06-03T03:52:13.938753Z", # trailing Z (server format; rejected by 3.10 fromisoformat) + "2026-06-03T03:52:13.938753+00:00", # explicit offset + ], + ) + def test_parses_utc_with_or_without_z(self, value): + parsed = _parse_iso_datetime(value) + assert parsed == datetime(2026, 6, 3, 3, 52, 13, 938753, tzinfo=timezone.utc) + assert parsed.utcoffset() == timezone.utc.utcoffset(None) + + +def _drive_auth_flow(auth: _CallableBearerAuth, statuses: list[int]) -> list[str | None]: + """Drive ``auth.auth_flow`` the way httpx does, feeding it the given response statuses. + + Returns the ``Authorization`` header value sent for each request (snapshotted at send time, + since httpx mutates the one request object in place on a 401-triggered retry). A retry shows up + as a second entry. + """ + gen = auth.auth_flow(httpx.Request("GET", "http://host/path")) + sent_auth: list[str | None] = [] + try: + request = next(gen) + for status in statuses: + sent_auth.append(request.headers.get("Authorization")) + request = gen.send(httpx.Response(status, request=request)) + except StopIteration: + pass + return sent_auth + + +class TestCallableBearerAuth: + def test_attaches_bearer_token_from_getter(self): + auth = _CallableBearerAuth(lambda: "abc") + assert _drive_auth_flow(auth, [200]) == ["Bearer abc"] + + def test_no_header_when_token_is_none(self): + auth = _CallableBearerAuth(lambda: None) + assert _drive_auth_flow(auth, [200]) == [None] + + def test_token_is_cached_within_ttl(self): + calls = [] + + def getter(): + calls.append(1) + return "t" + + auth = _CallableBearerAuth(getter, cache_ttl=1000.0) + _drive_auth_flow(auth, [200]) + _drive_auth_flow(auth, [200]) + # Second request reused the cached token rather than re-reading. + assert len(calls) == 1 + + def test_401_rereads_token_and_retries_once(self): + tokens = iter(["stale", "fresh"]) + auth = _CallableBearerAuth(lambda: next(tokens), cache_ttl=1000.0) + + # First request carries the stale token; the 401 re-reads (bypassing the cache) and retries + # with the fresh one. + assert _drive_auth_flow(auth, [401, 200]) == ["Bearer stale", "Bearer fresh"] + + def test_no_retry_when_reread_token_unchanged(self): + auth = _CallableBearerAuth(lambda: "same", cache_ttl=1000.0) + # Re-read returned the same token, so there's nothing to retry with. + assert _drive_auth_flow(auth, [401, 200]) == ["Bearer same"] diff --git a/airflow-core/tests/unit/dag_processing/test_manager.py b/airflow-core/tests/unit/dag_processing/test_manager.py index 941d090f22994..9128b9472f14a 100644 --- a/airflow-core/tests/unit/dag_processing/test_manager.py +++ b/airflow-core/tests/unit/dag_processing/test_manager.py @@ -23,7 +23,6 @@ import os import random import re -import shutil import signal import textwrap import time @@ -38,14 +37,13 @@ import msgspec import pytest import time_machine -from sqlalchemy import func, select +from sqlalchemy import select, update from uuid6 import uuid7 from airflow._shared.timezones import timezone from airflow.callbacks.callback_requests import DagCallbackRequest from airflow.dag_processing.bundles.base import BaseDagBundle from airflow.dag_processing.bundles.manager import DagBundlesManager -from airflow.dag_processing.dagbag import DagBag from airflow.dag_processing.manager import ( BundleState, DagFileInfo, @@ -53,19 +51,13 @@ DagFileStat, ) from airflow.dag_processing.processor import DagFileParsingResult, DagFileProcessorProcess -from airflow.models import DagModel, DbCallbackRequest -from airflow.models.asset import TaskOutletAssetReference -from airflow.models.dag_version import DagVersion +from airflow.models import DagModel from airflow.models.dagbundle import DagBundleModel -from airflow.models.dagcode import DagCode -from airflow.models.serialized_dag import SerializedDagModel from airflow.models.team import Team from airflow.utils.net import get_hostname from airflow.utils.session import create_session -from tests_common.test_utils.compat import ParseImportError from tests_common.test_utils.config import conf_vars -from tests_common.test_utils.dag import sync_dag_to_db from tests_common.test_utils.db import ( clear_db_assets, clear_db_callbacks, @@ -162,6 +154,52 @@ def _disable_examples(self): with conf_vars({("core", "load_examples"): "False"}): yield + @pytest.fixture(autouse=True) + def mock_dag_processing_client(self): + """Patch the DAG Processing API client so managers don't hit a real API server. + + ``DagFileProcessorManager`` now always holds a ``DagProcessingApiClient`` and routes + every persistence/metadata operation through it; in unit tests there is no API server, + so the real client raises ``httpx.ConnectError``. This replaces it with a mock whose + bundle-state read/write still round-trips through the metadata DB (so refresh-mechanic + tests behave as before), while persistence/reconcile/claim calls are inert. + + ``get_bundle_state`` falls back to a stale state when no ``DagBundleModel`` row exists. + In production the server creates that row during ``sync_bundles`` before parsing; here + ``sync_bundles`` is inert, so without the fallback every bundle would be skipped with + "Bundle model not found" and no files would ever be scanned. + + The mock is yielded so individual tests can override return values (e.g. + ``client.get_bundle_state.return_value = {"last_refreshed": ..., "version": ...}``). + """ + # A "stale" timestamp so bundles with refresh_interval=0 always refresh by default. + stale_last_refreshed = timezone.datetime(2000, 1, 1) + + def _db_get_bundle_state(bundle_name): + with create_session() as session: + row = session.scalar(select(DagBundleModel).where(DagBundleModel.name == bundle_name)) + if row is None: + return {"last_refreshed": stale_last_refreshed, "version": None} + return {"last_refreshed": row.last_refreshed, "version": row.version} + + def _db_update_bundle_state(bundle_name, *, last_refreshed, version): + values: dict = {"last_refreshed": last_refreshed} + if version is not None: + values["version"] = version + with create_session() as session: + session.execute( + update(DagBundleModel).where(DagBundleModel.name == bundle_name).values(**values) + ) + session.commit() + + with mock.patch("airflow.dag_processing.manager.DagProcessingApiClient") as mock_client_cls: + client = mock_client_cls.return_value + client.get_bundle_state.side_effect = _db_get_bundle_state + client.update_bundle_state.side_effect = _db_update_bundle_state + client.claim_priority_files.return_value = [] + client.fetch_callbacks.return_value = [] + yield client + def setup_method(self): clear_db_teams() clear_db_assets() @@ -204,76 +242,6 @@ def mock_processor(self, start_time: float | None = None) -> tuple[DagFileProces ret._open_sockets.clear() return ret, read_end - @pytest.fixture - def clear_parse_import_errors(self): - clear_db_import_errors() - - @pytest.mark.usefixtures("clear_parse_import_errors") - @conf_vars({("core", "load_examples"): "False"}) - def test_remove_file_clears_import_error(self, tmp_path, configure_testing_dag_bundle): - path_to_parse = tmp_path / "temp_dag.py" - - # Generate original import error - path_to_parse.write_text("an invalid airflow DAG") - - with configure_testing_dag_bundle(path_to_parse): - manager = DagFileProcessorManager( - max_runs=1, - processor_timeout=365 * 86_400, - ) - - manager.run() - - with create_session() as session: - import_errors = session.scalars(select(ParseImportError)).all() - assert len(import_errors) == 1 - - path_to_parse.unlink() - - # Rerun the parser once the dag file has been removed - manager.run() - - with create_session() as session: - import_errors = session.scalars(select(ParseImportError)).all() - - assert len(import_errors) == 0 - session.rollback() - - @pytest.mark.usefixtures("clear_parse_import_errors") - def test_clear_orphaned_import_errors_keeps_zip_inner_file_errors(self, session, tmp_path): - zip_path = tmp_path / "test_zip.zip" - _create_zip_bundle_with_valid_and_broken_dags(zip_path) - - session.add( - ParseImportError( - filename="test_zip.zip/broken_dag.py", - bundle_name="testing", - timestamp=timezone.utcnow(), - stacktrace="zip import error", - ) - ) - session.flush() - - manager = DagFileProcessorManager(max_runs=1) - manager.clear_orphaned_import_errors( - bundle_name="testing", - observed_filelocs=manager._get_observed_filelocs( - { - DagFileInfo( - bundle_name="testing", - rel_path=Path("test_zip.zip"), - bundle_path=tmp_path, - ) - } - ), - session=session, - ) - session.flush() - - import_errors = session.scalars(select(ParseImportError)).all() - assert len(import_errors) == 1 - assert import_errors[0].filename == "test_zip.zip/broken_dag.py" - def test_get_observed_filelocs_expands_zip_inner_paths(self, tmp_path): zip_path = tmp_path / "test_zip.zip" _create_zip_bundle_with_valid_and_broken_dags(zip_path) @@ -294,63 +262,24 @@ def test_get_observed_filelocs_expands_zip_inner_paths(self, tmp_path): "test_zip.zip/broken_dag.py", } - @pytest.mark.usefixtures("clear_parse_import_errors") - def test_refresh_dag_bundles_keeps_zip_inner_file_errors(self, session, tmp_path, configure_dag_bundles): - bundle_path = tmp_path / "bundleone" - bundle_path.mkdir() - zip_path = bundle_path / "test_zip.zip" - _create_zip_bundle_with_valid_and_broken_dags(zip_path) - - session.add( - ParseImportError( - filename="test_zip.zip/broken_dag.py", - bundle_name="bundleone", - timestamp=timezone.utcnow(), - stacktrace="zip import error", - ) - ) - session.flush() - - with configure_dag_bundles({"bundleone": bundle_path}): - DagBundlesManager().sync_bundles_to_db() - manager = DagFileProcessorManager(max_runs=1) - manager._dag_bundles = list(DagBundlesManager().get_all_dag_bundles()) - manager._refresh_dag_bundles({}) - - import_errors = session.scalars(select(ParseImportError)).all() - assert len(import_errors) == 1 - assert import_errors[0].filename == "test_zip.zip/broken_dag.py" - - def test_refresh_dag_bundles_calls_legacy_deactivate_deleted_dags_override( - self, tmp_path, configure_dag_bundles + def test_refresh_dag_bundles_reconciles_via_client( + self, tmp_path, configure_dag_bundles, mock_dag_processing_client ): bundle_path = tmp_path / "bundleone" bundle_path.mkdir() dag_path = bundle_path / "test_dag.py" dag_path.write_text("from airflow.sdk import DAG\n") - class BackwardCompatibleManager(DagFileProcessorManager): - seen_bundle_name: str | None = None - seen_present: set[DagFileInfo] | None = None - - def deactivate_deleted_dags(self, bundle_name: str, present: set[DagFileInfo]) -> None: - self.seen_bundle_name = bundle_name - self.seen_present = present - with configure_dag_bundles({"bundleone": bundle_path}): DagBundlesManager().sync_bundles_to_db() - manager = BackwardCompatibleManager(max_runs=1) + manager = DagFileProcessorManager(max_runs=1) manager._dag_bundles = list(DagBundlesManager().get_all_dag_bundles()) manager._refresh_dag_bundles({}) - assert manager.seen_bundle_name == "bundleone" - assert manager.seen_present == { - DagFileInfo( - bundle_name="bundleone", - rel_path=Path("test_dag.py"), - bundle_path=bundle_path, - ) - } + mock_dag_processing_client.reconcile.assert_called_once_with( + bundle_name="bundleone", + observed_filelocs={"test_dag.py"}, + ) @conf_vars({("core", "load_examples"): "False"}) def test_max_runs_when_no_files(self, tmp_path): @@ -847,53 +776,6 @@ def test_recently_modified_file_uses_versioned_stats_without_creating_duplicate_ assert known_file not in manager._file_stats assert versioned_file in manager._file_stats - def test_file_paths_in_queue_sorted_by_priority(self): - from airflow.models.dagbag import DagPriorityParsingRequest - - parsing_request = DagPriorityParsingRequest(relative_fileloc="file_1.py", bundle_name="dags-folder") - with create_session() as session: - session.add(parsing_request) - session.commit() - - file1 = DagFileInfo( - bundle_name="dags-folder", rel_path=Path("file_1.py"), bundle_path=TEST_DAGS_FOLDER - ) - file2 = DagFileInfo( - bundle_name="dags-folder", rel_path=Path("file_2.py"), bundle_path=TEST_DAGS_FOLDER - ) - - manager = DagFileProcessorManager(max_runs=1) - manager._dag_bundles = list(DagBundlesManager().get_all_dag_bundles()) - manager._file_queue = OrderedDict.fromkeys([file2, file1]) - manager._queue_requested_files_for_parsing() - assert manager._file_queue == OrderedDict.fromkeys([file1, file2]) - assert manager._force_refresh_bundles == {"dags-folder"} - with create_session() as session2: - parsing_request_after = session2.get(DagPriorityParsingRequest, parsing_request.id) - assert parsing_request_after is None - - def test_parsing_requests_only_bundles_being_parsed(self, testing_dag_bundle): - """Ensure the manager only handles parsing requests for bundles being parsed in this manager""" - from airflow.models.dagbag import DagPriorityParsingRequest - - with create_session() as session: - session.add(DagPriorityParsingRequest(relative_fileloc="file_1.py", bundle_name="dags-folder")) - session.add(DagPriorityParsingRequest(relative_fileloc="file_x.py", bundle_name="testing")) - session.commit() - - file1 = DagFileInfo( - bundle_name="dags-folder", rel_path=Path("file_1.py"), bundle_path=TEST_DAGS_FOLDER - ) - - manager = DagFileProcessorManager(max_runs=1) - manager._dag_bundles = list(DagBundlesManager().get_all_dag_bundles()) - manager._queue_requested_files_for_parsing() - assert manager._file_queue == OrderedDict.fromkeys([file1]) - with create_session() as session2: - parsing_request_after = session2.scalars(select(DagPriorityParsingRequest)).all() - assert len(parsing_request_after) == 1 - assert parsing_request_after[0].relative_fileloc == "file_x.py" - def test_queue_requested_files_for_parsing_uses_public_claim_hook(self): file1 = DagFileInfo( bundle_name="dags-folder", rel_path=Path("file_1.py"), bundle_path=TEST_DAGS_FOLDER @@ -933,10 +815,11 @@ def test_request_bundle_refresh_accepts_single_bundle_name(self): assert manager._force_refresh_bundles == {"bundleone"} @pytest.mark.usefixtures("testing_dag_bundle") - def test_scan_stale_dags(self, session): + def test_scan_stale_dags(self, mock_dag_processing_client): """ - Ensure that DAGs are marked inactive when the file is parsed but the - DagModel.last_parsed_time is not updated. + Ensure ``_scan_stale_dags`` forwards the parsed files to the DAG Processing API so the + server can deactivate DAGs whose files are no longer present. The actual stale-marking + now happens server-side. """ manager = DagFileProcessorManager( max_runs=1, @@ -951,21 +834,13 @@ def test_scan_stale_dags(self, session): rel_path=Path("test_example_bash_operator.py"), bundle_path=TEST_DAGS_FOLDER, ) - dagbag = DagBag( - test_dag_path.absolute_path, - include_examples=False, - bundle_path=test_dag_path.bundle_path, - ) - - # Add stale DAG to the DB - dag = dagbag.get_dag("test_example_bash_operator") - sync_dag_to_db(dag, session=session) # Add DAG to the file_parsing_stats + finish_time = timezone.utcnow() + timedelta(hours=1) stat = DagFileStat( num_dags=1, import_errors=0, - last_finish_time=timezone.utcnow() + timedelta(hours=1), + last_finish_time=finish_time, last_duration=1, run_count=1, last_num_of_db_queries=1, @@ -973,72 +848,19 @@ def test_scan_stale_dags(self, session): manager._files = [test_dag_path] manager._file_stats[test_dag_path] = stat - active_dag_count = session.scalar( - select(func.count(DagModel.dag_id)).where( - ~DagModel.is_stale, - DagModel.relative_fileloc == str(test_dag_path.rel_path), - DagModel.bundle_name == test_dag_path.bundle_name, - ) - ) - assert active_dag_count == 1 - + # Force the cleanup interval so the scan actually runs this cycle. + manager._last_deactivate_stale_dags_time = time.monotonic() - (manager.parsing_cleanup_interval + 1) manager._scan_stale_dags() - active_dag_count = session.scalar( - select(func.count(DagModel.dag_id)).where( - ~DagModel.is_stale, - DagModel.relative_fileloc == str(test_dag_path.rel_path), - DagModel.bundle_name == test_dag_path.bundle_name, - ) - ) - assert active_dag_count == 0 - - serialized_dag_count = session.scalar( - select(func.count(SerializedDagModel.dag_id)).where(SerializedDagModel.dag_id == dag.dag_id) - ) - # Deactivating the DagModel should not delete the SerializedDagModel - # SerializedDagModel gives history about Dags - assert serialized_dag_count == 1 - - @pytest.mark.usefixtures("testing_dag_bundle") - def test_deactivate_stale_dags_marks_dags_in_inactive_bundles(self, session): - """Dags whose bundle is no longer active should be marked stale even without a parse signal.""" - session.add(DagBundleModel(name="gone-bundle")) - session.flush() - session.execute( - DagBundleModel.__table__.update().where(DagBundleModel.name == "gone-bundle").values(active=False) - ) - session.add( - DagModel( - dag_id="dag_in_inactive_bundle", - bundle_name="gone-bundle", - relative_fileloc="some_file.py", - last_parsed_time=timezone.utcnow(), - is_stale=False, - ) - ) - session.add( - DagModel( - dag_id="dag_in_active_bundle", - bundle_name="testing", - relative_fileloc="other_file.py", - last_parsed_time=timezone.utcnow(), - is_stale=False, - ) - ) - session.flush() - - manager = DagFileProcessorManager(max_runs=1, processor_timeout=10 * 60) - manager.deactivate_stale_dags(last_parsed={}) - - is_stale_by_dag = dict( - session.execute( - select(DagModel.dag_id, DagModel.is_stale).where( - DagModel.dag_id.in_(["dag_in_inactive_bundle", "dag_in_active_bundle"]) - ) - ).all() - ) - assert is_stale_by_dag == {"dag_in_inactive_bundle": True, "dag_in_active_bundle": False} + mock_dag_processing_client.deactivate_stale_dags.assert_called_once() + call_kwargs = mock_dag_processing_client.deactivate_stale_dags.call_args.kwargs + assert call_kwargs["last_parsed"] == [ + { + "bundle_name": "testing", + "relative_fileloc": "test_example_bash_operator.py", + "last_finish_time": finish_time.isoformat(), + } + ] @mock.patch("airflow.dag_processing.manager.BundleUsageTrackingManager") def test_cleanup_stale_bundle_versions_interval(self, mock_bundle_manager): @@ -1119,24 +941,7 @@ def test_kill_timed_out_processors_no_kill(self): manager._kill_timed_out_processors() mock_kill.assert_not_called() - def test_handle_parsing_result_provides_its_own_session_when_caller_omits(self): - """``handle_parsing_result`` is wrapped in ``@provide_session`` so subclasses overriding it can run without a caller-supplied session.""" - manager = DagFileProcessorManager(max_runs=1) - file = DagFileInfo(bundle_name="testing", rel_path=Path("abc.txt"), bundle_path=TEST_DAGS_FOLDER) - manager._file_stats[file] = DagFileStat() - manager._bundle_versions["testing"] = "v1" - - processor, _ = self.mock_processor(start_time=time.monotonic() - 1) - processor.had_callbacks = False - processor.parsing_result = DagFileParsingResult(fileloc="abc.txt", serialized_dags=[]) - - with mock.patch.object(manager, "persist_parsing_result") as mock_persist: - manager.handle_parsing_result(file, processor) - - mock_persist.assert_called_once() - assert mock_persist.call_args.kwargs["session"] is not None - - def test_handle_parsing_result_throttles_retry_when_first_persist_fails(self, session): + def test_handle_parsing_result_throttles_retry_when_first_persist_fails(self): """Persist errors should throttle retries without claiming persistence succeeded.""" manager = DagFileProcessorManager(max_runs=1) file = DagFileInfo(bundle_name="testing", rel_path=Path("abc.txt"), bundle_path=TEST_DAGS_FOLDER) @@ -1150,7 +955,7 @@ def test_handle_parsing_result_throttles_retry_when_first_persist_fails(self, se processor.parsing_result = DagFileParsingResult(fileloc="abc.txt", serialized_dags=[]) with mock.patch.object(manager, "persist_parsing_result", side_effect=RuntimeError("boom")): - manager.handle_parsing_result(file, processor, session=session) + manager.handle_parsing_result(file, processor) assert manager._file_stats[file] is not original_stat assert manager._file_stats[file].num_dags == 0 @@ -1160,7 +965,7 @@ def test_handle_parsing_result_throttles_retry_when_first_persist_fails(self, se assert manager._file_stats[file].last_duration is not None assert manager.processed_recently(timezone.utcnow(), file) is True - def test_handle_parsing_result_updates_stats_after_successful_persist(self, session): + def test_handle_parsing_result_updates_stats_after_successful_persist(self): manager = DagFileProcessorManager(max_runs=1) file = DagFileInfo(bundle_name="testing", rel_path=Path("abc.txt"), bundle_path=TEST_DAGS_FOLDER) original_stat = DagFileStat( @@ -1179,7 +984,7 @@ def test_handle_parsing_result_updates_stats_after_successful_persist(self, sess processor.parsing_result = DagFileParsingResult(fileloc="abc.txt", serialized_dags=[]) with mock.patch.object(manager, "persist_parsing_result") as mock_persist: - manager.handle_parsing_result(file, processor, session=session) + manager.handle_parsing_result(file, processor) mock_persist.assert_called_once_with( bundle_name="testing", @@ -1188,7 +993,6 @@ def test_handle_parsing_result_updates_stats_after_successful_persist(self, sess parsing_result=processor.parsing_result, run_duration=mock.ANY, relative_fileloc="abc.txt", - session=session, ) assert manager._file_stats[file] is not original_stat assert manager._file_stats[file].run_count == 4 @@ -1311,23 +1115,15 @@ def test_dag_with_system_exit(self, configure_testing_dag_bundle): """ Test to check that a DAG with a system.exit() doesn't break the scheduler. """ - dag_id = "exit_test_dag" dag_directory = TEST_DAG_FOLDER.parent / "dags_with_system_exit" - # Delete the one valid DAG/SerializedDAG, and check that it gets re-created - clear_db_dags() - clear_db_serialized_dags() - with configure_testing_dag_bundle(dag_directory): manager = DagFileProcessorManager(max_runs=1) manager.run() - # Three files in folder should be processed + # The system-exit DAG must not abort the loop: all three files still get processed. assert sum(stat.run_count for stat in manager._file_stats.values()) == 3 - with create_session() as session: - assert session.get(DagModel, dag_id) is not None - @conf_vars({("core", "load_examples"): "False"}) @mock.patch("airflow.dag_processing.manager.stats.timing") def test_send_file_processing_statsd_timing( @@ -1361,84 +1157,6 @@ def test_send_file_processing_statsd_timing( ) @pytest.mark.usefixtures("testing_dag_bundle") - def test_refresh_dags_dir_doesnt_delete_zipped_dags( - self, tmp_path, session, configure_testing_dag_bundle, test_zip_path - ): - """Test DagFileProcessorManager._refresh_dag_dir method""" - dagbag = DagBag(dag_folder=tmp_path, include_examples=False) - dagbag.process_file(test_zip_path) - dag = dagbag.get_dag("test_zip_dag") - sync_dag_to_db(dag) - - with configure_testing_dag_bundle(test_zip_path): - manager = DagFileProcessorManager(max_runs=1) - manager.run() - - # Assert dag not deleted in SDM - assert SerializedDagModel.has_dag("test_zip_dag") - # assert code not deleted - assert DagCode.has_dag(dag.dag_id) - # assert dag still active - assert session.get(DagModel, dag.dag_id).is_stale is False - - @pytest.mark.usefixtures("testing_dag_bundle") - def test_refresh_dags_dir_deactivates_deleted_zipped_dags( - self, session, tmp_path, configure_testing_dag_bundle, test_zip_path - ): - """Test DagFileProcessorManager._refresh_dag_dir method""" - dag_id = "test_zip_dag" - filename = "test_zip.zip" - source_location = test_zip_path - bundle_path = Path(tmp_path, "test_refresh_dags_dir_deactivates_deleted_zipped_dags") - bundle_path.mkdir(exist_ok=True) - zip_dag_path = bundle_path / filename - shutil.copy(source_location, zip_dag_path) - - with configure_testing_dag_bundle(bundle_path): - session.commit() - manager = DagFileProcessorManager(max_runs=1) - manager.run() - - assert SerializedDagModel.has_dag(dag_id) - assert DagCode.has_dag(dag_id) - assert DagVersion.get_latest_version(dag_id) - dag = session.scalar(select(DagModel).where(DagModel.dag_id == dag_id)) - assert dag.is_stale is False - - os.remove(zip_dag_path) - - manager.run() - - assert SerializedDagModel.has_dag(dag_id) - assert DagCode.has_dag(dag_id) - assert DagVersion.get_latest_version(dag_id) - dag = session.scalar(select(DagModel).where(DagModel.dag_id == dag_id)) - assert dag.is_stale is True - - def test_deactivate_deleted_dags(self, dag_maker, session): - with dag_maker("test_dag1") as dag1: - dag1.relative_fileloc = "test_dag1.py" - with dag_maker("test_dag2") as dag2: - dag2.relative_fileloc = "test_dag2.py" - dag_maker.sync_dagbag_to_db() - - active_files = [ - DagFileInfo( - bundle_name="dag_maker", - rel_path=Path("test_dag1.py"), - bundle_path=TEST_DAGS_FOLDER, - ), - # Mimic that the test_dag2.py file is deleted - ] - - manager = DagFileProcessorManager(max_runs=1) - manager.deactivate_deleted_dags("dag_maker", set(active_files)) - - # The DAG from test_dag1.py is still active - assert session.get(DagModel, "test_dag1").is_stale is False - # and the DAG from test_dag2.py is deactivated - assert session.get(DagModel, "test_dag2").is_stale is True - @pytest.mark.parametrize( ("rel_filelocs", "expected_return", "expected_dag1_stale", "expected_dag2_stale"), [ @@ -1478,231 +1196,6 @@ def test_deactivate_deleted_dags_return_value( assert session.get(DagModel, "test_dag1").is_stale is expected_dag1_stale assert session.get(DagModel, "test_dag2").is_stale is expected_dag2_stale - @pytest.mark.parametrize( - ("active_files", "should_call_cleanup"), - [ - pytest.param( - [ - DagFileInfo( - bundle_name="dag_maker", - rel_path=Path("test_dag1.py"), - bundle_path=TEST_DAGS_FOLDER, - ), - # test_dag2.py is deleted - ], - True, # Should call cleanup - id="dags_deactivated", - ), - pytest.param( - [ - DagFileInfo( - bundle_name="dag_maker", - rel_path=Path("test_dag1.py"), - bundle_path=TEST_DAGS_FOLDER, - ), - DagFileInfo( - bundle_name="dag_maker", - rel_path=Path("test_dag2.py"), - bundle_path=TEST_DAGS_FOLDER, - ), - ], - False, # Should NOT call cleanup - id="no_dags_deactivated", - ), - ], - ) - @mock.patch("airflow.dag_processing.manager.remove_references_to_deleted_dags") - def test_manager_deactivate_deleted_dags_cleanup_behavior( - self, mock_remove_references, dag_maker, session, active_files, should_call_cleanup - ): - """Test that manager conditionally calls remove_references_to_deleted_dags based on whether DAGs were deactivated.""" - with dag_maker("test_dag1") as dag1: - dag1.relative_fileloc = "test_dag1.py" - with dag_maker("test_dag2") as dag2: - dag2.relative_fileloc = "test_dag2.py" - dag_maker.sync_dagbag_to_db() - - manager = DagFileProcessorManager(max_runs=1) - manager.deactivate_deleted_dags("dag_maker", set(active_files)) - - if should_call_cleanup: - mock_remove_references.assert_called_once() - else: - mock_remove_references.assert_not_called() - - @conf_vars({("core", "load_examples"): "False"}) - def test_fetch_callbacks_from_database(self, configure_testing_dag_bundle): - """Test _fetch_callbacks_from_db returns callbacks ordered by priority_weight desc.""" - - dag_filepath = TEST_DAG_FOLDER / "test_on_failure_callback_dag.py" - - callback1 = DagCallbackRequest( - dag_id="test_start_date_scheduling", - bundle_name="testing", - bundle_version=None, - filepath="test_on_failure_callback_dag.py", - is_failure_callback=True, - run_id="123", - ) - callback2 = DagCallbackRequest( - dag_id="test_start_date_scheduling", - bundle_name="testing", - bundle_version=None, - filepath="test_on_failure_callback_dag.py", - is_failure_callback=True, - run_id="456", - ) - - with create_session() as session: - session.add(DbCallbackRequest(callback=callback1, priority_weight=11)) - session.add(DbCallbackRequest(callback=callback2, priority_weight=10)) - - with configure_testing_dag_bundle(dag_filepath): - manager = DagFileProcessorManager(max_runs=1) - manager._dag_bundles = list(DagBundlesManager().get_all_dag_bundles()) - - with create_session() as session: - callbacks = manager._fetch_callbacks_from_db(session=session) - - # Should return callbacks ordered by priority_weight desc (highest first) - assert callbacks[0].run_id == "123" - assert callbacks[1].run_id == "456" - - assert len(session.scalars(select(DbCallbackRequest)).all()) == 0 - - @conf_vars( - { - ("dag_processor", "max_callbacks_per_loop"): "2", - ("core", "load_examples"): "False", - } - ) - def test_fetch_callbacks_from_database_max_per_loop(self, tmp_path, configure_testing_dag_bundle): - """Test DagFileProcessorManager.fetch_callbacks method.""" - dag_filepath = TEST_DAG_FOLDER / "test_on_failure_callback_dag.py" - - with create_session() as session: - for i in range(5): - callback = DagCallbackRequest( - dag_id="test_start_date_scheduling", - bundle_name="testing", - bundle_version=None, - filepath="test_on_failure_callback_dag.py", - is_failure_callback=True, - run_id=str(i), - ) - session.add(DbCallbackRequest(callback=callback, priority_weight=i)) - - with configure_testing_dag_bundle(dag_filepath): - manager = DagFileProcessorManager(max_runs=1) - - with create_session() as session: - manager.run() - assert len(session.scalars(select(DbCallbackRequest)).all()) == 3 - - with create_session() as session: - manager.run() - assert len(session.scalars(select(DbCallbackRequest)).all()) == 1 - - @conf_vars({("core", "load_examples"): "False"}) - def test_fetch_callbacks_ignores_other_bundles(self, configure_testing_dag_bundle): - """Ensure callbacks for bundles not owned by current dag processor manager are ignored and not deleted.""" - - dag_filepath = TEST_DAG_FOLDER / "test_on_failure_callback_dag.py" - - # Create two callbacks: one for the active 'testing' bundle and one for a different bundle - matching = DagCallbackRequest( - dag_id="test_start_date_scheduling", - bundle_name="testing", - bundle_version=None, - filepath="test_on_failure_callback_dag.py", - is_failure_callback=True, - run_id="match", - ) - non_matching = DagCallbackRequest( - dag_id="test_start_date_scheduling", - bundle_name="other-bundle", - bundle_version=None, - filepath="test_on_failure_callback_dag.py", - is_failure_callback=True, - run_id="no-match", - ) - - with create_session() as session: - session.add(DbCallbackRequest(callback=matching, priority_weight=100)) - session.add(DbCallbackRequest(callback=non_matching, priority_weight=200)) - - with configure_testing_dag_bundle(dag_filepath): - manager = DagFileProcessorManager(max_runs=1) - manager._dag_bundles = list(DagBundlesManager().get_all_dag_bundles()) - - with create_session() as session: - callbacks = manager._fetch_callbacks_from_db(session=session) - - # Only the matching callback should be returned - assert [c.run_id for c in callbacks] == ["match"] - - # The non-matching callback should remain in the DB - remaining = session.scalars(select(DbCallbackRequest)).all() - assert len(remaining) == 1 - # Decode remaining request and verify it's for the other bundle - remaining_req = remaining[0].get_callback_request() - assert remaining_req.bundle_name == "other-bundle" - - @conf_vars( - { - ("dag_processor", "max_callbacks_per_loop"): "2", - ("core", "load_examples"): "False", - } - ) - def test_fetch_callbacks_filters_by_bundle_before_limit(self, configure_testing_dag_bundle): - dag_filepath = TEST_DAG_FOLDER / "test_on_failure_callback_dag.py" - - matching = DagCallbackRequest( - dag_id="test_start_date_scheduling", - bundle_name="testing", - bundle_version=None, - filepath="test_on_failure_callback_dag.py", - is_failure_callback=True, - run_id="match", - ) - non_matching_1 = DagCallbackRequest( - dag_id="test_start_date_scheduling", - bundle_name="other-bundle-a", - bundle_version=None, - filepath="test_on_failure_callback_dag.py", - is_failure_callback=True, - run_id="no-match-1", - ) - non_matching_2 = DagCallbackRequest( - dag_id="test_start_date_scheduling", - bundle_name="other-bundle-b", - bundle_version=None, - filepath="test_on_failure_callback_dag.py", - is_failure_callback=True, - run_id="no-match-2", - ) - - with create_session() as session: - session.add(DbCallbackRequest(callback=non_matching_1, priority_weight=300)) - session.add(DbCallbackRequest(callback=non_matching_2, priority_weight=200)) - session.add(DbCallbackRequest(callback=matching, priority_weight=100)) - - with configure_testing_dag_bundle(dag_filepath): - manager = DagFileProcessorManager(max_runs=1) - manager._dag_bundles = list(DagBundlesManager().get_all_dag_bundles()) - - with create_session() as session: - callbacks = manager._fetch_callbacks_from_db(session=session) - - assert [c.run_id for c in callbacks] == ["match"] - - remaining = session.scalars(select(DbCallbackRequest)).all() - assert len(remaining) == 2 - assert {callback.bundle_name for callback in remaining} == { - "other-bundle-a", - "other-bundle-b", - } - @mock.patch.object(DagFileProcessorManager, "_get_logger_for_dag_file") def test_callback_queue(self, mock_get_logger, configure_testing_dag_bundle): mock_logger = MagicMock() @@ -1981,29 +1474,6 @@ def test_add_callback_skips_when_bundle_unconfigured(self, mock_bundle_manager): assert not manager._callback_to_execute - def test_fetch_callbacks_delegates_to_private_method(self): - manager = DagFileProcessorManager(max_runs=1) - expected: list = [mock.sentinel.callback] - with mock.patch.object(manager, "_fetch_callbacks_from_db", return_value=expected) as private: - assert manager.fetch_callbacks() is expected - private.assert_called_once_with() - - def test_dag_with_assets(self, session, configure_testing_dag_bundle): - """'Integration' test to ensure that the assets get parsed and stored correctly for parsed dags.""" - test_dag_path = str(TEST_DAG_FOLDER / "test_assets.py") - - with configure_testing_dag_bundle(test_dag_path): - manager = DagFileProcessorManager( - max_runs=1, - processor_timeout=365 * 86_400, - ) - manager.run() - - dag_model = session.get(DagModel, ("dag_with_skip_task")) - assert dag_model.task_outlet_asset_references == [ - TaskOutletAssetReference(asset_id=mock.ANY, dag_id="dag_with_skip_task", task_id="skip_task") - ] - def test_bundles_are_refreshed(self): """ Ensure bundles are refreshed by the manager, when necessary. @@ -2174,7 +1644,7 @@ def test_bundle_force_refresh(self): manager._refresh_dag_bundles({}) assert bundleone.refresh.call_count == 2 # forced refresh - def test_bundles_versions_are_stored(self, session): + def test_bundles_versions_are_stored(self, mock_dag_processing_client): config = [ { "name": "bundleone", @@ -2198,9 +1668,10 @@ def test_bundles_versions_are_stored(self, session): manager = DagFileProcessorManager(max_runs=1) manager.run() - with create_session() as session: - model = session.get(DagBundleModel, "bundleone") - assert model.version == "123" + # The refreshed bundle version is persisted via the DAG Processing API. + mock_dag_processing_client.update_bundle_state.assert_any_call( + "bundleone", last_refreshed=mock.ANY, version="123" + ) def test_non_versioned_bundle_get_version_not_called(self): config = [ @@ -2408,15 +1879,6 @@ def test_after_run_runs_when_parsing_loop_raises(self, tmp_path, configure_testi manager.run() after_run_mock.assert_called_once_with() - def test_prepare_server_process_context_can_be_skipped(self, tmp_path, configure_testing_dag_bundle): - """API-backed subclasses can skip server-context setup without losing selector/stats init.""" - with configure_testing_dag_bundle(tmp_path): - manager = DagFileProcessorManager(max_runs=1) - with mock.patch.object(manager, "prepare_server_process_context") as server_ctx_mock: - manager.run() - server_ctx_mock.assert_called_once_with() - assert manager.selector is not None - def test_prepare_bundles_can_be_overridden_without_sync(self, tmp_path, configure_testing_dag_bundle): """Subclasses can override `prepare_bundles` to skip `sync_bundles` (AIP-92 seam).""" with configure_testing_dag_bundle(tmp_path): @@ -2429,29 +1891,19 @@ def test_prepare_bundles_can_be_overridden_without_sync(self, tmp_path, configur sync_mock.assert_not_called() assert [b.name for b in manager._dag_bundles] == ["testing"] - def test_purge_inactive_dag_warnings_delegates_to_dagwarning(self): - """Default `purge_inactive_dag_warnings` calls `DagWarning.purge_inactive_dag_warnings`.""" - manager = DagFileProcessorManager(max_runs=1) - with mock.patch( - "airflow.dag_processing.manager.DagWarning.purge_inactive_dag_warnings" - ) as purge_mock: - manager.purge_inactive_dag_warnings() - purge_mock.assert_called_once_with() - def test_run_parsing_loop_uses_overridable_purge(self, tmp_path, configure_testing_dag_bundle): """`_run_parsing_loop` calls the overridable `purge_inactive_dag_warnings` seam.""" with configure_testing_dag_bundle(tmp_path): manager = DagFileProcessorManager(max_runs=1) - with ( - mock.patch.object(manager, "purge_inactive_dag_warnings") as purge_mock, - mock.patch( - "airflow.dag_processing.manager.DagWarning.purge_inactive_dag_warnings" - ) as direct_mock, - ): + with mock.patch.object(manager, "purge_inactive_dag_warnings") as purge_mock: manager.run() purge_mock.assert_called() - direct_mock.assert_not_called() + # min_file_process_interval=0 makes each file unconditionally eligible for re-parsing, so the + # three successive runs below don't hinge on the file's mtime exceeding its last-parsed time + # (a filesystem-granularity race that, under load, left a run waiting out the default 30s + # interval and blew the per-test timeout). + @conf_vars({("dag_processor", "min_file_process_interval"): "0"}) @mock.patch("airflow.dag_processing.manager.stats.gauge") def test_stats_total_parse_time(self, statsd_gauge_mock, tmp_path, configure_testing_dag_bundle): key = "dag_processing.total_parse_time" @@ -2478,80 +1930,8 @@ def test_stats_total_parse_time(self, statsd_gauge_mock, tmp_path, configure_tes assert len(gauge_values[key]) == 1 assert gauge_values[key][0] >= 1e-4 - dag_path.touch() # make the loop run faster gauge_values.clear() - # --- get_bundle_state / update_bundle_state --- - - def test_get_bundle_state_returns_none_for_missing_bundle(self): - manager = DagFileProcessorManager(max_runs=1) - assert manager.get_bundle_state("nonexistent_bundle") is None - - def test_get_bundle_state_returns_correct_state(self, session): - bundle_name = "test_state_bundle" - refreshed_at = timezone.datetime(2024, 1, 15, 12, 0, 0) - model = DagBundleModel(name=bundle_name, version="v1") - model.last_refreshed = refreshed_at - session.add(model) - session.commit() - - manager = DagFileProcessorManager(max_runs=1) - state = manager.get_bundle_state(bundle_name) - - assert state == BundleState(last_refreshed=refreshed_at, version="v1") - - def test_get_bundle_state_null_fields(self, session): - bundle_name = "test_null_state_bundle" - session.add(DagBundleModel(name=bundle_name)) - session.commit() - - manager = DagFileProcessorManager(max_runs=1) - state = manager.get_bundle_state(bundle_name) - - assert state == BundleState(last_refreshed=None, version=None) - - def test_update_bundle_state_sets_last_refreshed(self, session): - bundle_name = "test_update_bundle" - session.add(DagBundleModel(name=bundle_name)) - session.commit() - - refreshed_at = timezone.datetime(2024, 6, 1, 8, 0, 0) - manager = DagFileProcessorManager(max_runs=1) - manager.update_bundle_state(bundle_name, last_refreshed=refreshed_at, version=None) - - session.expire_all() - model = session.get(DagBundleModel, bundle_name) - assert model.last_refreshed == refreshed_at - assert model.version is None - - def test_update_bundle_state_sets_version(self, session): - bundle_name = "test_update_version_bundle" - session.add(DagBundleModel(name=bundle_name)) - session.commit() - - refreshed_at = timezone.datetime(2024, 6, 1, 8, 0, 0) - manager = DagFileProcessorManager(max_runs=1) - manager.update_bundle_state(bundle_name, last_refreshed=refreshed_at, version="abc123") - - session.expire_all() - model = session.get(DagBundleModel, bundle_name) - assert model.last_refreshed == refreshed_at - assert model.version == "abc123" - - def test_update_bundle_state_does_not_overwrite_version_when_none(self, session): - bundle_name = "test_preserve_version_bundle" - session.add(DagBundleModel(name=bundle_name, version="keep_me")) - session.commit() - - refreshed_at = timezone.datetime(2024, 6, 1, 8, 0, 0) - manager = DagFileProcessorManager(max_runs=1) - manager.update_bundle_state(bundle_name, last_refreshed=refreshed_at, version=None) - - session.expire_all() - model = session.get(DagBundleModel, bundle_name) - assert model.last_refreshed == refreshed_at - assert model.version == "keep_me" - def _make_refresh_bundle(self, *, supports_versioning=False, current_version=None): bundle = MagicMock(spec=BaseDagBundle) bundle.name = "mock_bundle" @@ -2572,14 +1952,13 @@ def _refresh_with_mocked_state(self, manager, bundle, initial_state): """ manager._dag_bundles = [bundle] manager._force_refresh_bundles = set() + manager._dag_processing_client = mock.MagicMock() mock_get = mock.patch.object(manager, "get_bundle_state", return_value=initial_state) mock_update = mock.patch.object(manager, "update_bundle_state") with ( mock_get as patched_get, mock_update as patched_update, mock.patch.object(manager, "_find_files_in_bundle", return_value=[]), - mock.patch.object(manager, "deactivate_deleted_dags"), - mock.patch.object(manager, "clear_orphaned_import_errors"), mock.patch.object(manager, "handle_removed_files"), mock.patch.object(manager, "_resort_file_queue"), mock.patch.object(manager, "_add_new_files_to_queue"), @@ -2640,6 +2019,7 @@ def test_refresh_dag_bundles_versioned_version_unchanged_persist_failure(self): manager._bundle_versions["mock_bundle"] = "v1" manager._dag_bundles = [bundle] manager._force_refresh_bundles = set() + manager._dag_processing_client = mock.MagicMock() known_files: dict[str, set[DagFileInfo]] = {} with ( @@ -2648,8 +2028,6 @@ def test_refresh_dag_bundles_versioned_version_unchanged_persist_failure(self): ), mock.patch.object(manager, "update_bundle_state", side_effect=Exception("DB error")), mock.patch.object(manager, "_find_files_in_bundle", return_value=[]) as mock_find, - mock.patch.object(manager, "deactivate_deleted_dags"), - mock.patch.object(manager, "clear_orphaned_import_errors"), mock.patch.object(manager, "handle_removed_files"), mock.patch.object(manager, "_resort_file_queue"), mock.patch.object(manager, "_add_new_files_to_queue"), @@ -2703,6 +2081,7 @@ def test_refresh_dag_bundles_update_bundle_state_failure_still_scans_files(self) manager = DagFileProcessorManager(max_runs=1) bundle = self._make_refresh_bundle() manager._dag_bundles = [bundle] + manager._dag_processing_client = mock.MagicMock() known_files: dict[str, set[DagFileInfo]] = {} with ( @@ -2711,8 +2090,6 @@ def test_refresh_dag_bundles_update_bundle_state_failure_still_scans_files(self) ), mock.patch.object(manager, "update_bundle_state", side_effect=Exception("API error")), mock.patch.object(manager, "_find_files_in_bundle", return_value=[]), - mock.patch.object(manager, "deactivate_deleted_dags"), - mock.patch.object(manager, "clear_orphaned_import_errors"), mock.patch.object(manager, "handle_removed_files"), mock.patch.object(manager, "_resort_file_queue"), mock.patch.object(manager, "_add_new_files_to_queue"), diff --git a/airflow-core/tests/unit/dag_processing/test_manager_api_persistence.py b/airflow-core/tests/unit/dag_processing/test_manager_api_persistence.py new file mode 100644 index 0000000000000..02c5eaab1e797 --- /dev/null +++ b/airflow-core/tests/unit/dag_processing/test_manager_api_persistence.py @@ -0,0 +1,317 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Tests for the API-backed persistence of DagFileProcessorManager (AIP-92). + +The DAG processor never reads or writes the metadata database directly; every persistence and +metadata operation is routed through the DAG Processing API client, and parse-time metadata +reads go through the remote Execution API. +""" + +from __future__ import annotations + +from datetime import datetime, timezone +from pathlib import Path +from unittest import mock + +from airflow.dag_processing.manager import ( + BundleState, + DagFileInfo, + DagFileProcessorManager, + _api_token, + _dag_processing_api_server_url, +) + +from tests_common.test_utils.config import conf_vars + +MGR = "airflow.dag_processing.manager" + + +# --- DAG Processing API URL resolution + client construction --- + + +@conf_vars({("api", "base_url"): "http://my-host:9999/"}) +def test_dag_processing_api_server_url_derives_sibling_mount(): + assert _dag_processing_api_server_url() == "http://my-host:9999/dag-processing" + + +@conf_vars({("core", "dag_processing_api_server_url"): "http://explicit/dag-processing"}) +def test_dag_processing_api_server_url_prefers_explicit(): + assert _dag_processing_api_server_url() == "http://explicit/dag-processing" + + +def test_api_token_reads_externally_provisioned_file(tmp_path): + # The processor only carries a token a trusted component wrote; it never mints one. + token_file = tmp_path / "token" + token_file.write_text("provisioned-token\n") + with conf_vars({("dag_processor", "api_token_path"): str(token_file)}): + assert _api_token() == "provisioned-token" + + +def test_api_token_none_when_unset(): + with conf_vars({("dag_processor", "api_token_path"): ""}): + assert _api_token() is None + + +@conf_vars({("core", "load_examples"): "False"}) +@mock.patch(f"{MGR}.DagProcessingApiClient") +def test_client_built_with_derived_default(mock_client_cls): + # No explicit URL -> always built, pointing at the derived /dag-processing mount. + manager = DagFileProcessorManager(max_runs=1) + assert manager._dag_processing_client is mock_client_cls.return_value + (url,) = mock_client_cls.call_args.args + assert url.endswith("/dag-processing") + + +@conf_vars( + { + ("core", "load_examples"): "False", + ("core", "dag_processing_api_server_url"): "http://api:8080/dag-processing", + } +) +@mock.patch(f"{MGR}.DagProcessingApiClient") +def test_client_built_with_explicit_url(mock_client_cls): + manager = DagFileProcessorManager(max_runs=1) + assert manager._dag_processing_client is mock_client_cls.return_value + # The client is built with the explicit URL (a self-signed token kwarg is also passed). + assert mock_client_cls.call_args.args[0] == "http://api:8080/dag-processing" + + +# --- every persistence/metadata method delegates to the client --- + + +@conf_vars({("core", "load_examples"): "False"}) +@mock.patch(f"{MGR}.DagProcessingApiClient") +def test_persist_delegates_to_client(mock_client_cls): + manager = DagFileProcessorManager(max_runs=1) + client = manager._dag_processing_client + + manager.persist_parsing_result( + bundle_name="b1", + bundle_version="v1", + version_data=None, + parsing_result=mock.MagicMock(), + run_duration=1.0, + relative_fileloc="dags/a.py", + ) + + client.persist_parsing_result.assert_called_once() + kwargs = client.persist_parsing_result.call_args.kwargs + assert kwargs["bundle_name"] == "b1" + assert kwargs["relative_fileloc"] == "dags/a.py" + + +@conf_vars({("core", "load_examples"): "False"}) +@mock.patch(f"{MGR}.DagProcessingApiClient") +def test_get_bundle_state_delegates_to_client(mock_client_cls): + manager = DagFileProcessorManager(max_runs=1) + client = manager._dag_processing_client + client.get_bundle_state.return_value = {"last_refreshed": None, "version": "v9"} + + state = manager.get_bundle_state("b1") + + client.get_bundle_state.assert_called_once_with("b1") + assert state == BundleState(last_refreshed=None, version="v9") + + +@conf_vars({("core", "load_examples"): "False"}) +@mock.patch(f"{MGR}.DagProcessingApiClient") +def test_get_bundle_state_returns_none_from_client(mock_client_cls): + manager = DagFileProcessorManager(max_runs=1) + client = manager._dag_processing_client + client.get_bundle_state.return_value = None + + assert manager.get_bundle_state("b1") is None + + +@conf_vars({("core", "load_examples"): "False"}) +@mock.patch(f"{MGR}.DagProcessingApiClient") +def test_update_bundle_state_delegates_to_client(mock_client_cls): + manager = DagFileProcessorManager(max_runs=1) + client = manager._dag_processing_client + ts = datetime.now(timezone.utc) + + manager.update_bundle_state("b1", last_refreshed=ts, version=None) + + client.update_bundle_state.assert_called_once_with("b1", last_refreshed=ts, version=None) + + +@conf_vars({("core", "load_examples"): "False"}) +@mock.patch(f"{MGR}.DagProcessingApiClient") +def test_sync_bundles_delegates_to_client(mock_client_cls): + manager = DagFileProcessorManager(max_runs=1) + client = manager._dag_processing_client + + manager.sync_bundles() + + client.sync_bundles.assert_called_once_with() + + +@conf_vars({("core", "load_examples"): "False"}) +@mock.patch(f"{MGR}.DagProcessingApiClient") +def test_deactivate_stale_dags_delegates_to_client(mock_client_cls): + manager = DagFileProcessorManager(max_runs=1) + client = manager._dag_processing_client + file_info = DagFileInfo(rel_path=Path("a.py"), bundle_name="b1") + + manager.deactivate_stale_dags(last_parsed={file_info: datetime.now(timezone.utc)}) + + client.deactivate_stale_dags.assert_called_once() + entries = client.deactivate_stale_dags.call_args.kwargs["last_parsed"] + assert entries[0]["bundle_name"] == "b1" + assert entries[0]["relative_fileloc"] == "a.py" + + +@conf_vars({("core", "load_examples"): "False"}) +@mock.patch(f"{MGR}.DagProcessingApiClient") +def test_purge_inactive_dag_warnings_delegates_to_client(mock_client_cls): + manager = DagFileProcessorManager(max_runs=1) + client = manager._dag_processing_client + + manager.purge_inactive_dag_warnings() + + client.purge_inactive_dag_warnings.assert_called_once_with() + + +@conf_vars({("core", "load_examples"): "False"}) +@mock.patch(f"{MGR}.DagProcessingApiClient") +def test_claim_priority_files_delegates_to_client(mock_client_cls): + manager = DagFileProcessorManager(max_runs=1) + client = manager._dag_processing_client + client.claim_priority_files.return_value = [{"bundle_name": "b1", "relative_fileloc": "a.py"}] + bundle = mock.MagicMock() + bundle.name = "b1" + bundle.path = Path("/bundles/b1") + manager._dag_bundles = [bundle] + + files = manager.claim_priority_files() + + client.claim_priority_files.assert_called_once_with(["b1"]) + assert len(files) == 1 + assert files[0].bundle_name == "b1" + assert str(files[0].rel_path) == "a.py" + + +@conf_vars({("core", "load_examples"): "False"}) +@mock.patch(f"{MGR}.DagProcessingApiClient") +def test_fetch_callbacks_delegates_to_client(mock_client_cls): + manager = DagFileProcessorManager(max_runs=1) + client = manager._dag_processing_client + client.fetch_callbacks.return_value = ["callback-request"] + + result = manager.fetch_callbacks() + + client.fetch_callbacks.assert_called_once() + assert result == ["callback-request"] + + +# --- Job lifecycle runs through the API --- + + +def test_run_dag_processor_job_uses_api(): + from airflow.cli.commands.dag_processor_command import _run_dag_processor_job + + client = mock.MagicMock() + client.register_job.return_value = 99 + processor = mock.MagicMock() + processor._dag_processing_client = client + job_runner = mock.MagicMock(job_type="DagProcessorJob", processor=processor) + job_runner.job.heartrate = 0.0 # no throttle so the first heartbeat fires immediately + + _run_dag_processor_job(job_runner) + + client.register_job.assert_called_once_with("DagProcessorJob") + job_runner._execute.assert_called_once() + client.complete_job.assert_called_once_with(99, state="success") + # the heartbeat hook routes through the API client + processor.heartbeat() + client.job_heartbeat.assert_called_with(99) + + +# --- parse-time metadata reads go to the remote Execution API --- + + +@conf_vars( + { + ("core", "load_examples"): "False", + ("core", "execution_api_server_url"): "http://exec:8080/execution", + } +) +@mock.patch(f"{MGR}.DagProcessingApiClient") +def test_parse_time_client_is_remote(mock_client_cls): + manager = DagFileProcessorManager(max_runs=1) + with ( + mock.patch("airflow.dag_processing.manager._api_token", return_value="tok"), + mock.patch("airflow.dag_processing.manager.Client") as client_cls, + ): + _ = manager.client + + client_cls.assert_called_once() + kwargs = client_cls.call_args.kwargs + assert kwargs["base_url"] == "http://exec:8080/execution" + assert kwargs["token"] == "tok" + assert kwargs["dry_run"] is False + + +# --- bundle-init credentials resolve through the Execution API (no metadata DB) --- + + +def test_secrets_comms_resolves_connection_via_execution_api(): + from airflow.dag_processing.manager import _DagProcessorSecretsComms + from airflow.sdk.execution_time.comms import GetConnection + + client = mock.MagicMock() + comms = _DagProcessorSecretsComms(client) + with mock.patch( + "airflow.dag_processing.manager.handle_get_connection", + return_value=("conn-result", {}), + ) as handler: + result = comms.send(GetConnection(conn_id="git_default")) + + handler.assert_called_once() + assert handler.call_args.args[0] is client # routed through the Execution API client + assert result == "conn-result" + + +def test_secrets_comms_ignores_unsupported_messages(): + from airflow.dag_processing.manager import _DagProcessorSecretsComms + + comms = _DagProcessorSecretsComms(mock.MagicMock()) + # e.g. a MaskSecret emitted while masking a fetched secret: no-op, never touches the client. + assert comms.send(object()) is None + + +@conf_vars({("core", "load_examples"): "False"}) +@mock.patch(f"{MGR}.DagProcessingApiClient") +def test_before_run_installs_secrets_comms(mock_client_cls): + from airflow.dag_processing.manager import _DagProcessorSecretsComms + from airflow.sdk.execution_time import task_runner + + manager = DagFileProcessorManager(max_runs=1) + manager.__dict__["client"] = mock.MagicMock() # pre-populate the cached_property + + sentinel = object() + old = getattr(task_runner, "SUPERVISOR_COMMS", sentinel) + try: + manager._setup_secrets_comms() + assert isinstance(task_runner.SUPERVISOR_COMMS, _DagProcessorSecretsComms) + finally: + if old is sentinel: + if hasattr(task_runner, "SUPERVISOR_COMMS"): + del task_runner.SUPERVISOR_COMMS + else: + task_runner.SUPERVISOR_COMMS = old diff --git a/airflow-core/tests/unit/dag_processing/test_no_db_mode.py b/airflow-core/tests/unit/dag_processing/test_no_db_mode.py new file mode 100644 index 0000000000000..7712d446fa3bb --- /dev/null +++ b/airflow-core/tests/unit/dag_processing/test_no_db_mode.py @@ -0,0 +1,128 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""No-DB-mode harness for the DAG processor (AIP-92). + +The goal is a DAG processor that runs with zero metadata-DB access (all DB I/O +forwarded over the API). `forbid_db_access` makes any metadata-DB connection +raise, so tests can assert a code path is DB-free. Each phase that moves a DB +touchpoint behind the API flips one of the strict-xfail gates below to passing. +""" + +from __future__ import annotations + +import contextlib +from datetime import datetime, timezone +from unittest import mock + +import pytest + +from airflow.dag_processing.manager import DagFileProcessorManager + +from tests_common.test_utils.config import conf_vars + +MGR = "airflow.dag_processing.manager" +API_URL = "http://localhost:8080/dag-processing" + +# These tests patch ``settings.engine`` to forbid connections, so they need a real engine to +# patch. In the Non-DB test environment ``settings.engine`` is None, so run them as DB tests. +pytestmark = pytest.mark.db_test + + +@contextlib.contextmanager +def forbid_db_access(): + """Raise ``AssertionError`` if any metadata-DB connection is opened in the block.""" + from airflow import settings + + def _forbid(*args, **kwargs): + raise AssertionError("Metadata DB access is forbidden in no-DB mode") + + with ( + mock.patch.object(settings.engine, "connect", side_effect=_forbid), + mock.patch.object(settings.engine, "raw_connection", side_effect=_forbid), + ): + yield + + +@conf_vars({("core", "load_examples"): "False"}) +def test_forbid_db_access_blocks_real_queries(): + """The guard itself works: a real query under it raises.""" + from sqlalchemy import text + + from airflow.utils.session import create_session + + with forbid_db_access(), pytest.raises(AssertionError, match="forbidden"): + with create_session() as session: + session.execute(text("SELECT 1")).all() + + +@conf_vars({("core", "load_examples"): "False", ("core", "dag_processing_api_server_url"): API_URL}) +@mock.patch(f"{MGR}.DagProcessingApiClient") +def test_bundle_state_is_db_free_when_api_configured(mock_client_cls): + """Reading and updating bundle state go through the API, opening no DB connection.""" + client = mock_client_cls.return_value + client.get_bundle_state.return_value = None + manager = DagFileProcessorManager(max_runs=1) + with forbid_db_access(): + assert manager.get_bundle_state("any-bundle") is None + manager.update_bundle_state("any-bundle", last_refreshed=datetime.now(timezone.utc), version="v1") + client.get_bundle_state.assert_called_once_with("any-bundle") + client.update_bundle_state.assert_called_once() + + +@conf_vars({("core", "load_examples"): "False", ("core", "dag_processing_api_server_url"): API_URL}) +@mock.patch(f"{MGR}.DagProcessingApiClient") +def test_sync_bundles_is_db_free_when_api_configured(mock_client_cls): + """Startup bundle sync goes through the API, opening no DB connection.""" + manager = DagFileProcessorManager(max_runs=1) + with forbid_db_access(): + manager.sync_bundles() + mock_client_cls.return_value.sync_bundles.assert_called_once_with() + + +@conf_vars({("core", "load_examples"): "False", ("core", "dag_processing_api_server_url"): API_URL}) +@mock.patch(f"{MGR}.DagProcessingApiClient") +def test_stale_sweep_is_db_free_when_api_configured(mock_client_cls): + """Time-based stale-dag deactivation and warning purge go through the API, no DB.""" + manager = DagFileProcessorManager(max_runs=1) + with forbid_db_access(): + manager.deactivate_stale_dags(last_parsed={}) + manager.purge_inactive_dag_warnings() + client = mock_client_cls.return_value + client.deactivate_stale_dags.assert_called_once() + client.purge_inactive_dag_warnings.assert_called_once_with() + + +@conf_vars({("core", "load_examples"): "False", ("core", "dag_processing_api_server_url"): API_URL}) +@mock.patch(f"{MGR}.DagProcessingApiClient") +def test_priority_claim_is_db_free_when_api_configured(mock_client_cls): + """Claiming priority parse requests goes through the API, opening no DB connection.""" + mock_client_cls.return_value.claim_priority_files.return_value = [] + manager = DagFileProcessorManager(max_runs=1) + with forbid_db_access(): + assert manager.claim_priority_files() == [] + mock_client_cls.return_value.claim_priority_files.assert_called_once() + + +@conf_vars({("core", "load_examples"): "False", ("core", "dag_processing_api_server_url"): API_URL}) +@mock.patch(f"{MGR}.DagProcessingApiClient") +def test_callback_claim_is_db_free_when_api_configured(mock_client_cls): + """Fetching callbacks goes through the API, opening no DB connection.""" + mock_client_cls.return_value.fetch_callbacks.return_value = [] + manager = DagFileProcessorManager(max_runs=1) + with forbid_db_access(): + assert manager.fetch_callbacks() == [] + mock_client_cls.return_value.fetch_callbacks.assert_called_once() diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py index ef92372f57378..4069299ea5004 100644 --- a/airflow-core/tests/unit/dag_processing/test_processor.py +++ b/airflow-core/tests/unit/dag_processing/test_processor.py @@ -50,7 +50,7 @@ TaskCallbackRequest, ) from airflow.dag_processing.dagbag import DagBag -from airflow.dag_processing.manager import DagFileProcessorManager, process_parse_results +from airflow.dag_processing.manager import process_parse_results from airflow.dag_processing.processor import ( DagFileParseRequest, DagFileParsingResult, @@ -706,43 +706,6 @@ def test_import_error_updates_timestamps(): assert stat.import_errors == 1 -def test_persist_parsing_result_calls_update_db(): - """persist_parsing_result should delegate to update_dag_parsing_results_in_db with transformed args.""" - parsing_result = DagFileParsingResult( - fileloc="test.py", - serialized_dags=[], - import_errors={"dags/broken.py": "SyntaxError"}, - warnings=[], - ) - - manager = MagicMock(spec=DagFileProcessorManager) - session = MagicMock() - # Call the real method on the mock instance - with patch( - "airflow.dag_processing.manager.update_dag_parsing_results_in_db", autospec=True - ) as mock_update: - DagFileProcessorManager.persist_parsing_result( - manager, - bundle_name="test-bundle", - bundle_version="v1", - version_data=None, - parsing_result=parsing_result, - run_duration=1.5, - relative_fileloc="dags/test.py", - session=session, - ) - - mock_update.assert_called_once() - call_kwargs = mock_update.call_args.kwargs - assert call_kwargs["bundle_name"] == "test-bundle" - assert call_kwargs["bundle_version"] == "v1" - assert call_kwargs["parse_duration"] == 1.5 - assert call_kwargs["session"] is session - assert call_kwargs["import_errors"] == {("test-bundle", "dags/broken.py"): "SyntaxError"} - assert ("test-bundle", "dags/test.py") in call_kwargs["files_parsed"] - assert ("test-bundle", "dags/broken.py") in call_kwargs["files_parsed"] - - class TestExecuteCallbacks: def test_execute_callbacks_locks_bundle_version(self): callbacks = [ diff --git a/chart/templates/dag-processor/dag-processor-deployment.yaml b/chart/templates/dag-processor/dag-processor-deployment.yaml index 1f42c471afd0d..faf9a761bbb4e 100644 --- a/chart/templates/dag-processor/dag-processor-deployment.yaml +++ b/chart/templates/dag-processor/dag-processor-deployment.yaml @@ -31,6 +31,8 @@ {{- $containerSecurityContextLogGroomerSidecar := include "containerSecurityContext" (list .Values.dagProcessor.logGroomerSidecar .Values) }} {{- $containerSecurityContextWaitForMigrations := include "containerSecurityContext" (list .Values.dagProcessor.waitForMigrations .Values) }} {{- $containerLifecycleHooks := or .Values.dagProcessor.containerLifecycleHooks .Values.containerLifecycleHooks }} +{{- $provisionApiToken := .Values.dagProcessor.apiToken.provisionViaInitContainer }} +{{- $apiTokenDir := dir .Values.dagProcessor.apiToken.path }} apiVersion: apps/v1 kind: Deployment metadata: @@ -137,6 +139,35 @@ spec: {{- tpl (toYaml .Values.dagProcessor.waitForMigrations.env) $ | nindent 12 }} {{- end }} {{- end }} + {{- if $provisionApiToken }} + - name: provision-api-token + # The DAG processor parses user code, so it must not hold the signing key. This trusted + # init container holds it (IncludeJwtSecret true), mints the processor's token into a + # shared volume, and exits; the dag-processor container below reads that file without the + # signing key. Re-created on every pod start, so the token is fresh per restart. + resources: {{- toYaml .Values.dagProcessor.resources | nindent 12 }} + image: {{ template "airflow_image" . }} + imagePullPolicy: {{ .Values.images.airflow.pullPolicy }} + securityContext: {{ $containerSecurityContext | nindent 12 }} + args: + - "bash" + - "-c" + - "exec airflow provision-dag-processor-token" + volumeMounts: + {{- if .Values.volumeMounts }} + {{- toYaml .Values.volumeMounts | nindent 12 }} + {{- end }} + {{- if .Values.dagProcessor.extraVolumeMounts }} + {{- tpl (toYaml .Values.dagProcessor.extraVolumeMounts) . | nindent 12 }} + {{- end }} + - name: dag-processor-api-token + mountPath: {{ $apiTokenDir | quote }} + {{- include "airflow_config_mount" . | nindent 12 }} + envFrom: {{- include "custom_airflow_environment_from" . | default "\n []" | indent 10 }} + env: + {{- include "custom_airflow_environment" . | indent 10 }} + {{- include "standard_airflow_environment" (merge (dict "IncludeJwtSecret" true) .) | indent 10 }} + {{- end }} {{- if and .Values.dags.gitSync.enabled (not .Values.dags.persistence.enabled) }} {{- include "git_sync_container" (dict "Values" .Values "is_init" "true" "Template" .Template) | nindent 8 }} {{- end }} @@ -170,6 +201,11 @@ spec: {{- if .Values.logs.persistence.subPath }} subPath: {{ .Values.logs.persistence.subPath }} {{- end }} + {{- if $provisionApiToken }} + - name: dag-processor-api-token + mountPath: {{ $apiTokenDir | quote }} + readOnly: true + {{- end }} {{- include "airflow_config_mount" . | nindent 12 }} {{- if or .Values.dags.persistence.enabled .Values.dags.gitSync.enabled }} {{- include "airflow_dags_mount" . | nindent 12 }} @@ -277,4 +313,8 @@ spec: - name: logs emptyDir: {{- toYaml (default (dict) .Values.logs.emptyDirConfig) | nindent 12 }} {{- end }} + {{- if $provisionApiToken }} + - name: dag-processor-api-token + emptyDir: {} + {{- end }} {{- end }} diff --git a/chart/tests/helm_tests/airflow_core/test_dag_processor.py b/chart/tests/helm_tests/airflow_core/test_dag_processor.py index 9bc1f6ebb2e84..b92294f4a1aa8 100644 --- a/chart/tests/helm_tests/airflow_core/test_dag_processor.py +++ b/chart/tests/helm_tests/airflow_core/test_dag_processor.py @@ -63,7 +63,9 @@ def test_disable_wait_for_migration(self): actual = jmespath.search( "spec.template.spec.initContainers[?name=='wait-for-airflow-migrations']", docs[0] ) - assert actual is None + # No wait-for-migrations init container (other init containers, e.g. the API token minter, + # may still be present, so this is an empty match rather than a missing initContainers list). + assert not actual def test_wait_for_migration_security_contexts_are_configurable(self): docs = render_chart( diff --git a/chart/values.schema.json b/chart/values.schema.json index bc094f6b5fb67..152ef13743bad 100644 --- a/chart/values.schema.json +++ b/chart/values.schema.json @@ -5742,6 +5742,23 @@ "type": "boolean", "default": true }, + "apiToken": { + "description": "Provisioning of the bearer token the DAG processor presents to the API server. The processor never mints its own token (it parses user code); a trusted init container mints it instead.", + "type": "object", + "additionalProperties": false, + "properties": { + "provisionViaInitContainer": { + "description": "Mint the token in an init container (which holds the signing key) into a shared volume the processor reads. Set false to provision the token yourself via ``extraInitContainers``/``extraVolumes`` and ``config.dag_processor.api_token_path``.", + "type": "boolean", + "default": true + }, + "path": { + "description": "Path of the token file, shared between the init container and the processor.", + "type": "string", + "default": "/opt/airflow/dag-processor-token/token" + } + } + }, "dagBundleConfigList": { "description": "Define Dag bundles in a structured YAML format. This will be automatically converted to JSON string format for config.dag_processor.dag_bundle_config_list.", "type": "array", diff --git a/chart/values.yaml b/chart/values.yaml index 4213d755884cb..1243f58fea43e 100644 --- a/chart/values.yaml +++ b/chart/values.yaml @@ -2687,6 +2687,16 @@ triggerer: dagProcessor: enabled: true + # The DAG processor authenticates to the API server with a bearer token but never mints its own + # (it parses user code, so it must not hold the signing key). By default a trusted init container + # mints one -- using the deployment signing key -- into a shared volume that the processor reads + # and re-reads as it is rotated. Set provisionViaInitContainer to false to provision the token + # yourself via extraInitContainers/extraVolumes plus `config.dag_processor.api_token_path`. + apiToken: + provisionViaInitContainer: true + # Where the minted token is written and read, shared between the init container and the processor. + path: /opt/airflow/dag-processor-token/token + # Dag Bundle Configuration # Define Dag bundles in a structured YAML format. This will be automatically # converted to JSON string format for `config.dag_processor.dag_bundle_config_list`. @@ -3947,6 +3957,9 @@ config: # This value is generated by default from `.Values.dagProcessor.dagBundleConfigList` using the `dag_bundle_config_list` helper function. # It is recommended to configure this via `dagProcessor.dagBundleConfigList` rather than overriding `config.dag_processor.dag_bundle_config_list` directly. dag_bundle_config_list: '{{ include "dag_bundle_config_list" . }}' + # Points the DAG processor at the token the init container mints (see dagProcessor.apiToken). + # Set in config (not as an env var) so it is read by both the init container and the processor. + api_token_path: '{{ if .Values.dagProcessor.apiToken.provisionViaInitContainer }}{{ .Values.dagProcessor.apiToken.path }}{{ end }}' elasticsearch: json_format: 'True' log_id_template: "{dag_id}-{task_id}-{run_id}-{map_index}-{try_number}" diff --git a/task-sdk-integration-tests/docker-compose.yaml b/task-sdk-integration-tests/docker-compose.yaml index 82143e98b5e4a..402ef265bee8f 100644 --- a/task-sdk-integration-tests/docker-compose.yaml +++ b/task-sdk-integration-tests/docker-compose.yaml @@ -34,6 +34,9 @@ x-airflow-common: AIRFLOW__CORE__EXECUTION_API_SERVER_URL: 'http://airflow-apiserver:8080/execution/' AIRFLOW__API__BASE_URL: 'http://airflow-apiserver:8080/' AIRFLOW__API_AUTH__JWT_SECRET: 'test-secret-key-for-testing' + # The DAG processor never mints its own token (it parses user code): airflow-init mints one to + # this shared file and the processor reads it. See airflow-init / [dag_processor] api_token_path. + AIRFLOW__DAG_PROCESSOR__API_TOKEN_PATH: '/opt/airflow/api-token/dag-processor-token' AIRFLOW_VAR_TEST_VARIABLE_KEY: 'test_variable_value' AIRFLOW_CONN_TEST_CONNECTION: 'postgresql://testuser:testpass@testhost:5432/testdb' HOST_OS: ${HOST_OS:-linux} @@ -41,6 +44,7 @@ x-airflow-common: volumes: - ./dags:/opt/airflow/dags - ./logs:/opt/airflow/logs + - airflow-api-token:/opt/airflow/api-token depends_on: &airflow-common-depends-on postgres: @@ -72,6 +76,8 @@ services: /entrypoint airflow version echo "Running airflow config list to create default config file if missing." /entrypoint airflow config list >/dev/null + echo "Minting the DAG processor API token (processor parses user code, so it never mints its own)." + /entrypoint airflow provision-dag-processor-token if [ "${HOST_OS}" == "linux" ]; then echo "Change ownership of files in /opt/airflow to ${AIRFLOW_UID}:0" chown -R "${AIRFLOW_UID}:0" /opt/airflow/ @@ -128,3 +134,7 @@ services: <<: *airflow-common-depends-on airflow-init: condition: service_completed_successfully + +volumes: + # Shares the DAG processor API token from airflow-init (which mints it) to the processor. + airflow-api-token: