diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py index 25f6397c074ec..36ef372e39aa2 100644 --- a/airflow-core/src/airflow/models/dagrun.py +++ b/airflow-core/src/airflow/models/dagrun.py @@ -836,7 +836,7 @@ def fetch_task_instances( dag_id: str | None = None, run_id: str | None = None, task_ids: list[str] | None = None, - state: Iterable[TaskInstanceState | None] | None = None, + state: TaskInstanceState | Iterable[TaskInstanceState | None] | None = None, *, session: Session = NEW_SESSION, ) -> list[TI]: @@ -917,7 +917,7 @@ def _check_last_n_dagruns_failed(self, dag_id, max_consecutive_failed_dag_runs, @provide_session def get_task_instances( self, - state: Iterable[TaskInstanceState | None] | None = None, + state: TaskInstanceState | Iterable[TaskInstanceState | None] | None = None, *, session: Session = NEW_SESSION, ) -> list[TI]: diff --git a/airflow-core/tests/unit/models/test_dagrun.py b/airflow-core/tests/unit/models/test_dagrun.py index 9405d8a955512..210bd3316f457 100644 --- a/airflow-core/tests/unit/models/test_dagrun.py +++ b/airflow-core/tests/unit/models/test_dagrun.py @@ -780,6 +780,42 @@ def test_get_task_instance_on_empty_dagrun(self, dag_maker, session): ti = dag_run.get_task_instance("test_short_circuit_false") assert ti is None + @pytest.mark.parametrize( + ("state_filter", "expected_task_ids"), + [ + pytest.param(TaskInstanceState.SUCCESS, {"success_task"}, id="single-state"), + pytest.param( + [TaskInstanceState.SUCCESS, TaskInstanceState.FAILED], + {"success_task", "failed_task"}, + id="iterable-of-states", + ), + ], + ) + def test_get_task_instances_state_accepts_single_or_iterable( + self, dag_maker, session, state_filter, expected_task_ids + ): + with dag_maker( + dag_id="test_get_task_instances_state", + schedule=datetime.timedelta(days=1), + start_date=DEFAULT_DATE, + ) as dag: + EmptyOperator(task_id="success_task") + EmptyOperator(task_id="failed_task") + EmptyOperator(task_id="skipped_task") + + dag_run = self.create_dag_run( + dag=dag, + task_states={ + "success_task": TaskInstanceState.SUCCESS, + "failed_task": TaskInstanceState.FAILED, + "skipped_task": TaskInstanceState.SKIPPED, + }, + session=session, + ) + + tis = dag_run.get_task_instances(state=state_filter, session=session) + assert {ti.task_id for ti in tis} == expected_task_ids + def test_get_latest_runs(self, dag_maker, session): with dag_maker( dag_id="test_latest_runs_1", schedule=datetime.timedelta(days=1), start_date=DEFAULT_DATE