diff --git a/src/strands_tools/workflow.py b/src/strands_tools/workflow.py index c9dd0555..ce34816f 100644 --- a/src/strands_tools/workflow.py +++ b/src/strands_tools/workflow.py @@ -581,6 +581,28 @@ def get_ready_tasks(self, workflow: Dict) -> List[Dict]: ready_tasks.sort(key=lambda x: x.get("priority", 3), reverse=True) return ready_tasks + def _get_task_timeout(self, workflow: Dict, task_id: str) -> float: + """Return the configured timeout (seconds) for a task, defaulting to 300.""" + for task in workflow["tasks"]: + if task["task_id"] == task_id: + return task.get("timeout", 300) + return 300 + + def _next_task_deadline(self, active_futures: Dict, workflow: Dict) -> Optional[float]: + """Return seconds until the soonest active-task deadline, or None if no active tasks.""" + if not active_futures: + return None + now = time.time() + soonest = None + for namespaced_task_id in active_futures: + task_id = namespaced_task_id.split(":", 1)[1] if ":" in namespaced_task_id else namespaced_task_id + task_timeout = self._get_task_timeout(workflow, task_id) + start_time = self.task_executor.start_times.get(namespaced_task_id, now) + deadline_remaining = (start_time + task_timeout) - now + if soonest is None or deadline_remaining < soonest: + soonest = deadline_remaining + return max(0.0, soonest) if soonest is not None else None + def start_workflow(self, workflow_id: str) -> Dict: """Start or resume workflow execution with true parallel processing.""" try: @@ -633,19 +655,22 @@ def start_workflow(self, workflow_id: str) -> Dict: active_futures.update(new_futures) logger.debug(f"📤 Submitted {len(tasks_to_submit)} tasks for execution") - # Wait for any task to complete + # Wait for any task to complete, but no longer than the next task deadline if active_futures: - done, _ = wait(active_futures.values(), return_when=FIRST_COMPLETED) + wait_timeout = self._next_task_deadline(active_futures, workflow) + done, _ = wait(active_futures.values(), return_when=FIRST_COMPLETED, timeout=wait_timeout) - # Process completed tasks - completed_task_ids = [] + # Process completed and timed-out tasks + finished_namespaced_ids = [] + now = time.time() for namespaced_task_id, future in active_futures.items(): + # Extract original task_id from namespaced version + task_id = ( + namespaced_task_id.split(":", 1)[1] if ":" in namespaced_task_id else namespaced_task_id + ) + if future in done: - # Extract original task_id from namespaced version - task_id = ( - namespaced_task_id.split(":", 1)[1] if ":" in namespaced_task_id else namespaced_task_id - ) - completed_task_ids.append(namespaced_task_id) + finished_namespaced_ids.append(namespaced_task_id) try: result = future.result() @@ -676,9 +701,27 @@ def start_workflow(self, workflow_id: str) -> Dict: } completed_tasks.add(task_id) logger.error(f"❌ Task '{task_id}' failed: {str(e)}") + else: + # Still-running future: enforce per-task timeout. We can't reliably + # terminate the worker thread (Python lacks a portable mechanism), + # but we mark the task failed so dependent tasks and the overall + # workflow can make progress instead of blocking indefinitely. + task_timeout = self._get_task_timeout(workflow, task_id) + start_time = self.task_executor.start_times.get(namespaced_task_id, now) + if (now - start_time) >= task_timeout: + future.cancel() + workflow["task_results"][task_id] = { + **workflow["task_results"][task_id], + "status": "error", + "result": [{"text": f"Task execution timeout after {task_timeout}s"}], + "completed_at": datetime.now(timezone.utc).isoformat(), + } + completed_tasks.add(task_id) + logger.error(f"⏱️ Task '{task_id}' timed out after {task_timeout}s") + finished_namespaced_ids.append(namespaced_task_id) - # Remove completed tasks from active futures - for task_id in completed_task_ids: + # Remove completed and timed-out tasks from active futures + for task_id in finished_namespaced_ids: del active_futures[task_id] # Store updated workflow state diff --git a/tests/test_workflow.py b/tests/test_workflow.py index bfc35b3f..1b738f68 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -4,6 +4,8 @@ import json import tempfile +import time +from datetime import datetime, timezone from pathlib import Path from unittest.mock import MagicMock, patch @@ -260,6 +262,69 @@ def test_task_id_namespacing(self): extracted_id = namespaced_task_id.split(":", 1)[1] if ":" in namespaced_task_id else namespaced_task_id assert extracted_id == "task1" + def test_start_workflow_enforces_task_timeout(self, mock_parent_agent): + """A task that exceeds its configured timeout is marked as error so the workflow can finish.""" + # WorkflowManager is a singleton (__new__-based), so reset it before AND after + # this test to avoid leaking our patched execute_task / shut-down TaskExecutor + # into other tests. + workflow_module.WorkflowManager._instance = None + try: + with ( + tempfile.TemporaryDirectory() as temp_dir, + patch.object(workflow_module, "WORKFLOW_DIR", Path(temp_dir)), + ): + manager = workflow_module.WorkflowManager(mock_parent_agent) + + # execute_task replacement that sleeps far longer than the configured task timeout + def hanging_execute_task(task, workflow): + time.sleep(2.0) + return {"status": "success", "content": [{"text": "should not be reached"}]} + + manager.execute_task = hanging_execute_task + + workflow_id = "timeout_test" + workflow_data = { + "workflow_id": workflow_id, + "created_at": datetime.now(timezone.utc).isoformat(), + "status": "created", + "tasks": [ + { + "task_id": "slow_task", + "description": "Hangs longer than its timeout", + "timeout": 0.2, + "priority": 3, + "dependencies": [], + } + ], + "task_results": { + "slow_task": { + "status": "pending", + "result": None, + "priority": 3, + "model_provider": None, + "tools": [], + } + }, + "parallel_execution": True, + } + manager.store_workflow(workflow_id, workflow_data) + + started = time.time() + result = manager.start_workflow(workflow_id) + elapsed = time.time() - started + + # Workflow should give up on the hung task well before its 2.0s sleep completes + assert elapsed < 1.5, f"start_workflow blocked on a hung task ({elapsed:.2f}s)" + assert result["status"] == "success" + + task_result = manager.get_workflow(workflow_id)["task_results"]["slow_task"] + assert task_result["status"] == "error" + timeout_text = task_result["result"][0]["text"] + assert "timeout" in timeout_text.lower() + assert "0.2" in timeout_text + finally: + workflow_module.WorkflowManager._instance = None + class TestWorkflowStatus: """Test workflow status functionality."""