Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 54 additions & 11 deletions src/strands_tools/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,28 @@ def get_ready_tasks(self, workflow: Dict) -> List[Dict]:
ready_tasks.sort(key=lambda x: x.get("priority", 3), reverse=True)
return ready_tasks

def _get_task_timeout(self, workflow: Dict, task_id: str, default: float = 300) -> float:
    """Return the configured timeout (seconds) for a task.

    Args:
        workflow: Workflow dict; its ``"tasks"`` entry is scanned for a task
            whose ``"task_id"`` equals ``task_id``.
        task_id: Un-namespaced task identifier to look up.
        default: Timeout in seconds used when the task is not found, has no
            ``"timeout"`` key, or has ``"timeout"`` set to ``None``.

    Returns:
        The first matching task's ``"timeout"`` value, or ``default``.
    """
    for task in workflow.get("tasks") or ():
        # .get() so an entry without "task_id" is skipped instead of raising.
        if task.get("task_id") != task_id:
            continue
        timeout = task.get("timeout")
        # An explicit None would otherwise reach callers that do
        # ``start_time + timeout`` and fail with TypeError.
        return default if timeout is None else timeout
    return default

def _next_task_deadline(self, active_futures: Dict, workflow: Dict) -> Optional[float]:
    """Return seconds until the soonest active-task deadline.

    Args:
        active_futures: Mapping keyed by namespaced task id; only the keys
            are read here.
        workflow: Workflow dict passed through to ``_get_task_timeout``.

    Returns:
        ``None`` when ``active_futures`` is empty; otherwise the smallest
        remaining time across active tasks, clamped to ``0.0`` so an
        already-expired deadline never yields a negative wait.
    """
    if not active_futures:
        return None

    fallback_timeout = 300  # same default that _get_task_timeout documents
    now = time.time()
    start_times = self.task_executor.start_times

    remaining = []
    for namespaced_task_id in active_futures:
        # Drop everything up to the first ":" to recover the plain task id;
        # ids without a ":" are used as-is.
        _, separator, suffix = namespaced_task_id.partition(":")
        task_id = suffix if separator else namespaced_task_id

        task_timeout = self._get_task_timeout(workflow, task_id)
        if task_timeout is None:
            # A task stored with "timeout": None must not break the
            # arithmetic below with a TypeError.
            task_timeout = fallback_timeout

        # A task with no recorded start time is treated as starting now.
        started_at = start_times.get(namespaced_task_id, now)
        remaining.append((started_at + task_timeout) - now)

    return max(0.0, min(remaining))

def start_workflow(self, workflow_id: str) -> Dict:
"""Start or resume workflow execution with true parallel processing."""
try:
Expand Down Expand Up @@ -633,19 +655,22 @@ def start_workflow(self, workflow_id: str) -> Dict:
active_futures.update(new_futures)
logger.debug(f"📤 Submitted {len(tasks_to_submit)} tasks for execution")

# Wait for any task to complete
# Wait for any task to complete, but no longer than the next task deadline
if active_futures:
done, _ = wait(active_futures.values(), return_when=FIRST_COMPLETED)
wait_timeout = self._next_task_deadline(active_futures, workflow)
done, _ = wait(active_futures.values(), return_when=FIRST_COMPLETED, timeout=wait_timeout)

# Process completed tasks
completed_task_ids = []
# Process completed and timed-out tasks
finished_namespaced_ids = []
now = time.time()
for namespaced_task_id, future in active_futures.items():
# Extract original task_id from namespaced version
task_id = (
namespaced_task_id.split(":", 1)[1] if ":" in namespaced_task_id else namespaced_task_id
)

if future in done:
# Extract original task_id from namespaced version
task_id = (
namespaced_task_id.split(":", 1)[1] if ":" in namespaced_task_id else namespaced_task_id
)
completed_task_ids.append(namespaced_task_id)
finished_namespaced_ids.append(namespaced_task_id)
try:
result = future.result()

Expand Down Expand Up @@ -676,9 +701,27 @@ def start_workflow(self, workflow_id: str) -> Dict:
}
completed_tasks.add(task_id)
logger.error(f"❌ Task '{task_id}' failed: {str(e)}")
else:
# Still-running future: enforce per-task timeout. We can't reliably
# terminate the worker thread (Python lacks a portable mechanism),
# but we mark the task failed so dependent tasks and the overall
# workflow can make progress instead of blocking indefinitely.
task_timeout = self._get_task_timeout(workflow, task_id)
start_time = self.task_executor.start_times.get(namespaced_task_id, now)
if (now - start_time) >= task_timeout:
future.cancel()
workflow["task_results"][task_id] = {
**workflow["task_results"][task_id],
"status": "error",
"result": [{"text": f"Task execution timeout after {task_timeout}s"}],
"completed_at": datetime.now(timezone.utc).isoformat(),
}
completed_tasks.add(task_id)
logger.error(f"⏱️ Task '{task_id}' timed out after {task_timeout}s")
finished_namespaced_ids.append(namespaced_task_id)

# Remove completed tasks from active futures
for task_id in completed_task_ids:
# Remove completed and timed-out tasks from active futures
for task_id in finished_namespaced_ids:
del active_futures[task_id]

# Store updated workflow state
Expand Down
65 changes: 65 additions & 0 deletions tests/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import json
import tempfile
import time
from datetime import datetime, timezone
from pathlib import Path
from unittest.mock import MagicMock, patch

Expand Down Expand Up @@ -260,6 +262,69 @@ def test_task_id_namespacing(self):
extracted_id = namespaced_task_id.split(":", 1)[1] if ":" in namespaced_task_id else namespaced_task_id
assert extracted_id == "task1"

def test_start_workflow_enforces_task_timeout(self, mock_parent_agent):
    """A task that exceeds its configured timeout is marked as error so the workflow can finish.

    The stubbed ``execute_task`` blocks for up to 2.0s while the task's
    configured timeout is 0.2s. The test asserts that ``start_workflow``
    returns in under 1.5s and that the stored task result is an error
    mentioning the timeout.
    """
    # Local import: only this test needs it.
    import threading

    # Lets the test release the blocked worker during cleanup, so no pool
    # thread is left sleeping after the assertions have run.
    release_worker = threading.Event()

    # WorkflowManager is a singleton (__new__-based), so reset it before AND after
    # this test to avoid leaking our patched execute_task / shut-down TaskExecutor
    # into other tests.
    workflow_module.WorkflowManager._instance = None
    try:
        with (
            tempfile.TemporaryDirectory() as temp_dir,
            patch.object(workflow_module, "WORKFLOW_DIR", Path(temp_dir)),
        ):
            manager = workflow_module.WorkflowManager(mock_parent_agent)

            # execute_task replacement that blocks far longer than the configured
            # task timeout (same 2.0s ceiling as a plain sleep, but interruptible).
            def hanging_execute_task(task, workflow):
                release_worker.wait(timeout=2.0)
                return {"status": "success", "content": [{"text": "should not be reached"}]}

            manager.execute_task = hanging_execute_task

            workflow_id = "timeout_test"
            workflow_data = {
                "workflow_id": workflow_id,
                "created_at": datetime.now(timezone.utc).isoformat(),
                "status": "created",
                "tasks": [
                    {
                        "task_id": "slow_task",
                        "description": "Hangs longer than its timeout",
                        "timeout": 0.2,
                        "priority": 3,
                        "dependencies": [],
                    }
                ],
                "task_results": {
                    "slow_task": {
                        "status": "pending",
                        "result": None,
                        "priority": 3,
                        "model_provider": None,
                        "tools": [],
                    }
                },
                "parallel_execution": True,
            }
            manager.store_workflow(workflow_id, workflow_data)

            started = time.time()
            result = manager.start_workflow(workflow_id)
            elapsed = time.time() - started

            # Workflow should give up on the hung task well before its 2.0s block completes
            assert elapsed < 1.5, f"start_workflow blocked on a hung task ({elapsed:.2f}s)"
            assert result["status"] == "success"

            task_result = manager.get_workflow(workflow_id)["task_results"]["slow_task"]
            assert task_result["status"] == "error"
            timeout_text = task_result["result"][0]["text"]
            assert "timeout" in timeout_text.lower()
            assert "0.2" in timeout_text
    finally:
        # Unblock the worker whether or not the assertions passed.
        release_worker.set()
        workflow_module.WorkflowManager._instance = None


class TestWorkflowStatus:
"""Test workflow status functionality."""
Expand Down
Loading