Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions airflow-core/src/airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,7 +836,7 @@ def fetch_task_instances(
dag_id: str | None = None,
run_id: str | None = None,
task_ids: list[str] | None = None,
state: Iterable[TaskInstanceState | None] | None = None,
state: TaskInstanceState | Iterable[TaskInstanceState | None] | None = None,
*,
session: Session = NEW_SESSION,
) -> list[TI]:
Expand Down Expand Up @@ -917,7 +917,7 @@ def _check_last_n_dagruns_failed(self, dag_id, max_consecutive_failed_dag_runs,
@provide_session
def get_task_instances(
self,
state: Iterable[TaskInstanceState | None] | None = None,
state: TaskInstanceState | Iterable[TaskInstanceState | None] | None = None,
*,
session: Session = NEW_SESSION,
) -> list[TI]:
Expand Down
36 changes: 36 additions & 0 deletions airflow-core/tests/unit/models/test_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,42 @@ def test_get_task_instance_on_empty_dagrun(self, dag_maker, session):
ti = dag_run.get_task_instance("test_short_circuit_false")
assert ti is None

@pytest.mark.parametrize(
("state_filter", "expected_task_ids"),
[
pytest.param(TaskInstanceState.SUCCESS, {"success_task"}, id="single-state"),
pytest.param(
[TaskInstanceState.SUCCESS, TaskInstanceState.FAILED],
{"success_task", "failed_task"},
id="iterable-of-states",
),
],
)
def test_get_task_instances_state_accepts_single_or_iterable(
self, dag_maker, session, state_filter, expected_task_ids
):
with dag_maker(
dag_id="test_get_task_instances_state",
schedule=datetime.timedelta(days=1),
start_date=DEFAULT_DATE,
) as dag:
EmptyOperator(task_id="success_task")
EmptyOperator(task_id="failed_task")
EmptyOperator(task_id="skipped_task")

dag_run = self.create_dag_run(
dag=dag,
task_states={
"success_task": TaskInstanceState.SUCCESS,
"failed_task": TaskInstanceState.FAILED,
"skipped_task": TaskInstanceState.SKIPPED,
},
session=session,
)

tis = dag_run.get_task_instances(state=state_filter, session=session)
assert {ti.task_id for ti in tis} == expected_task_ids

def test_get_latest_runs(self, dag_maker, session):
with dag_maker(
dag_id="test_latest_runs_1", schedule=datetime.timedelta(days=1), start_date=DEFAULT_DATE
Expand Down
Loading