diff --git a/README.md b/README.md index e945edf4..76be0fb9 100644 --- a/README.md +++ b/README.md @@ -693,9 +693,13 @@ agent = Agent(tools=provider.tools) response = agent("discover available agents and send a greeting message") # The agent will automatically use the available tools: -# - discover_agent(url) to find agents -# - list_discovered_agents() to see all discovered agents -# - send_message(message_text, target_agent_url) to communicate +# - a2a_discover_agent(url) to find agents +# - a2a_list_discovered_agents() to see all discovered agents +# - a2a_send_message(message_text, target_agent_url) to communicate +# - a2a_get_task(target_agent_url, task_id) to monitor task status and retrieve history +# - a2a_cancel_task(target_agent_url, task_id) to cancel running tasks +# - a2a_get_conversation_state(target_agent_url) to check current conversation state +# - a2a_clear_conversation_state(target_agent_url) to start a fresh conversation ``` ### Diagram diff --git a/src/strands_tools/a2a_client.py b/src/strands_tools/a2a_client.py index a653f860..d91b4927 100644 --- a/src/strands_tools/a2a_client.py +++ b/src/strands_tools/a2a_client.py @@ -6,8 +6,11 @@ Key Features: - Agent discovery through agent cards from multiple URLs - Message sending to specific A2A agents +- Task monitoring and status retrieval with history and artifacts +- Task cancellation with terminal state validation - Push notification support for real-time task completion alerts - Custom authentication support via httpx client arguments +- Context and task ID persistence for multi-turn conversations Usage Examples: @@ -24,24 +27,66 @@ ... "timeout": 300 ... } ... ) + + Multi-turn conversation with context persistence: + >>> provider = A2AClientToolProvider(known_agent_urls=["http://agent.example.com"]) + >>> # First message - context_id and task_id will be returned in response + >>> result1 = await provider.a2a_send_message("Start a task", "http://agent.example.com") + >>> # Second message - context_id is automatically reused for the same agent + >>> result2 = await provider.a2a_send_message("Continue the task", "http://agent.example.com") + + Task monitoring and status retrieval: + >>> provider = A2AClientToolProvider(known_agent_urls=["http://agent.example.com"]) + >>> # Get task status with full history + >>> task_status = await provider.a2a_get_task("http://agent.example.com", "task-123") + >>> # Get task status with limited history for better performance + >>> task_status = await provider.a2a_get_task("http://agent.example.com", "task-456", history_length=10) + + Task cancellation: + >>> provider = A2AClientToolProvider(known_agent_urls=["http://agent.example.com"]) + >>> # Cancel a running task + >>> result = await provider.a2a_cancel_task("http://agent.example.com", "task-123") + >>> # Cancel with context_id validation for security + >>> result = await provider.a2a_cancel_task("http://agent.example.com", "task-456", context_id="ctx-789") """ import asyncio import logging +from dataclasses import dataclass, field from typing import Any from uuid import uuid4 import httpx from a2a.client import A2ACardResolver, ClientConfig, ClientFactory -from a2a.types import AgentCard, Message, Part, PushNotificationConfig, Role, TextPart +from a2a.types import AgentCard, Message, Part, PushNotificationConfig, Role, TaskIdParams, TaskQueryParams, TextPart from strands import tool from strands.types.tools import AgentTool DEFAULT_TIMEOUT = 300 # set request timeout to 5 minutes +# Terminal task states - tasks in these states should not be continued +TERMINAL_TASK_STATES = {"completed", "canceled", "failed", "rejected"} + logger = logging.getLogger(__name__) +@dataclass +class ActiveTask: + """Represents an active (non-terminal) task for an agent.""" + + task_id: str + state: str + context_id: str + + +@dataclass +class ConversationState: + """Tracks conversation state for a target agent URL.""" + + context_id: str | None = None + active_tasks: dict[str, ActiveTask] = field(default_factory=dict) # task_id -> ActiveTask + + class A2AClientToolProvider: """A2A Client tool provider that manages multiple A2A agents and exposes synchronous tools.""" @@ -93,6 +138,10 @@ def __init__( id=f"strands-webhook-{uuid4().hex[:8]}", url=self._webhook_url, token=self._webhook_token ) + # Conversation state tracking for context_id and task_id persistence + # Key: target_agent_url, Value: ConversationState + self._conversation_states: dict[str, ConversationState] = {} + @property def tools(self) -> list[AgentTool]: """Extract all @tool decorated methods from this instance.""" @@ -177,6 +226,62 @@ async def _discover_agent_card(self, url: str) -> AgentCard: return agent_card + def _get_conversation_state(self, target_agent_url: str) -> ConversationState: + """Get or create conversation state for a target agent URL.""" + if target_agent_url not in self._conversation_states: + self._conversation_states[target_agent_url] = ConversationState() + return self._conversation_states[target_agent_url] + + def _update_conversation_state( + self, + target_agent_url: str, + context_id: str | None = None, + task_id: str | None = None, + task_state: str | None = None, + ) -> None: + """ + Update conversation state from server response. + + Args: + target_agent_url: The agent URL this state belongs to + context_id: The context ID from the response (if any) + task_id: The task ID from the response (if any) + task_state: The task state from the response (if any) + """ + state = self._get_conversation_state(target_agent_url) + + # Store context_id if we don't have one yet + if context_id and not state.context_id: + state.context_id = context_id + logger.debug(f"Stored context_id={context_id} for {target_agent_url}") + + # Track task state + if task_id and task_state: + if task_state in TERMINAL_TASK_STATES: + # Remove task from active tracking when it reaches a terminal state + if task_id in state.active_tasks: + del state.active_tasks[task_id] + logger.debug(f"Removed terminal task {task_id} (state={task_state}) for {target_agent_url}") + else: + # Update or add active task + if state.context_id: + state.active_tasks[task_id] = ActiveTask( + task_id=task_id, state=task_state, context_id=state.context_id + ) + logger.debug(f"Updated active task {task_id} (state={task_state}) for {target_agent_url}") + + def _get_task_id_for_continuation(self, target_agent_url: str) -> str | None: + """ + Get the task_id to use for continuing a conversation. + + Returns: + The task_id if there's exactly one active non-terminal task, None otherwise. + """ + state = self._get_conversation_state(target_agent_url) + if len(state.active_tasks) == 1: + return next(iter(state.active_tasks.values())).task_id + return None + @tool async def a2a_discover_agent(self, url: str) -> dict[str, Any]: """ @@ -252,7 +357,12 @@ async def _list_discovered_agents(self) -> dict[str, Any]: @tool async def a2a_send_message( - self, message_text: str, target_agent_url: str, message_id: str | None = None + self, + message_text: str, + target_agent_url: str, + message_id: str | None = None, + context_id: str | None = None, + task_id: str | None = None, ) -> dict[str, Any]: """ Send a message to a specific A2A agent and return the response. @@ -261,11 +371,24 @@ async def a2a_send_message( refers to an agent by name only, use a2a_list_discovered_agents first to get the correct URL. Never guess, generate, or hallucinate URLs. + For multi-turn conversations: + - context_id: Automatically persisted from the first response and reused + in subsequent messages to the same agent. Can be explicitly overridden. + - task_id: If there's exactly one active non-terminal task for the agent, + it will be automatically reused. Can be explicitly provided to continue + a specific task, or omitted to create a new task. + Args: message_text: The message content to send to the agent target_agent_url: The exact URL of the target A2A agent (user-provided URL or from a2a_list_discovered_agents) message_id: Optional message ID for tracking (generates UUID if not provided) + context_id: Optional context ID for continuing a conversation. + If not provided, uses the persisted context_id from previous + interactions with this agent (if any). + task_id: Optional task ID for continuing a specific task. + If not provided and there's exactly one active task for this agent, + that task_id will be used automatically. Returns: dict: Response data including: @@ -274,11 +397,18 @@ async def a2a_send_message( - error: Error message (if failed) - message_id: The message ID used - target_agent_url: The agent URL that was contacted + - context_id: The context ID used/returned (for conversation continuity) + - task_id: The task ID used/returned (if applicable) """ - return await self._send_message(message_text, target_agent_url, message_id) + return await self._send_message(message_text, target_agent_url, message_id, context_id, task_id) async def _send_message( - self, message_text: str, target_agent_url: str, message_id: str | None = None + self, + message_text: str, + target_agent_url: str, + message_id: str | None = None, + context_id: str | None = None, + task_id: str | None = None, ) -> dict[str, Any]: """Internal async implementation for send_message.""" @@ -293,38 +423,87 @@ async def _send_message( if message_id is None: message_id = uuid4().hex + # Resolve context_id: explicit > persisted + effective_context_id = context_id + if effective_context_id is None: + state = self._get_conversation_state(target_agent_url) + effective_context_id = state.context_id + + # Resolve task_id: explicit > auto-continuation + effective_task_id = task_id + if effective_task_id is None: + effective_task_id = self._get_task_id_for_continuation(target_agent_url) + message = Message( kind="message", role=Role.user, parts=[Part(TextPart(kind="text", text=message_text))], message_id=message_id, + context_id=effective_context_id, + task_id=effective_task_id, ) - logger.info(f"Sending message to {target_agent_url}") + logger.info( + f"Sending message to {target_agent_url} " + f"(context_id={effective_context_id}, task_id={effective_task_id})" + ) # With streaming=False, this will yield exactly one result async for event in client.send_message(message): + response_context_id = None + response_task_id = None + response_task_state = None + if isinstance(event, Message): # Direct message response + response_data = event.model_dump(mode="python", exclude_none=True) + response_context_id = getattr(event, "context_id", None) + response_task_id = getattr(event, "task_id", None) + + # Update conversation state from response + self._update_conversation_state( + target_agent_url, response_context_id, response_task_id, response_task_state + ) + return { "status": "success", - "response": event.model_dump(mode="python", exclude_none=True), + "response": response_data, "message_id": message_id, "target_agent_url": target_agent_url, + "context_id": response_context_id or effective_context_id, + "task_id": response_task_id or effective_task_id, } elif isinstance(event, tuple) and len(event) == 2: # (Task, UpdateEvent) tuple - extract the task task, update_event = event + task_data = task.model_dump(mode="python", exclude_none=True) + + # Extract IDs and state from task + response_context_id = getattr(task, "context_id", None) + response_task_id = getattr(task, "id", None) + task_status = getattr(task, "status", None) + if task_status: + response_task_state = getattr(task_status, "state", None) + if hasattr(response_task_state, "value"): + response_task_state = response_task_state.value + + # Update conversation state from response + self._update_conversation_state( + target_agent_url, response_context_id, response_task_id, response_task_state + ) + return { "status": "success", "response": { - "task": task.model_dump(mode="python", exclude_none=True), + "task": task_data, "update": ( update_event.model_dump(mode="python", exclude_none=True) if update_event else None ), }, "message_id": message_id, "target_agent_url": target_agent_url, + "context_id": response_context_id or effective_context_id, + "task_id": response_task_id or effective_task_id, } else: # Fallback for unexpected response types @@ -333,6 +512,8 @@ async def _send_message( "response": {"raw_response": str(event)}, "message_id": message_id, "target_agent_url": target_agent_url, + "context_id": effective_context_id, + "task_id": effective_task_id, } # This should never be reached with streaming=False @@ -341,6 +522,8 @@ async def _send_message( "error": "No response received from agent", "message_id": message_id, "target_agent_url": target_agent_url, + "context_id": effective_context_id, + "task_id": effective_task_id, } except Exception as e: @@ -350,4 +533,296 @@ async def _send_message( "error": str(e), "message_id": message_id, "target_agent_url": target_agent_url, + "context_id": context_id, + "task_id": task_id, + } + + @tool + async def a2a_get_conversation_state(self, target_agent_url: str) -> dict[str, Any]: + """ + Get the current conversation state for a target agent. + + This returns the persisted context_id and active tasks for the specified + agent URL, useful for debugging or understanding the conversation state. + + Args: + target_agent_url: The URL of the target A2A agent + + Returns: + dict: Conversation state including: + - context_id: The persisted context ID (if any) + - active_tasks: List of active (non-terminal) tasks + - target_agent_url: The agent URL queried + """ + state = self._get_conversation_state(target_agent_url) + return { + "status": "success", + "context_id": state.context_id, + "active_tasks": [ + {"task_id": task.task_id, "state": task.state, "context_id": task.context_id} + for task in state.active_tasks.values() + ], + "target_agent_url": target_agent_url, + } + + @tool + async def a2a_clear_conversation_state(self, target_agent_url: str) -> dict[str, Any]: + """ + Clear the conversation state for a target agent. + + This removes the persisted context_id and all active tasks for the + specified agent URL, effectively starting a new conversation. + + Args: + target_agent_url: The URL of the target A2A agent + + Returns: + dict: Result of the operation including: + - status: "success" if cleared successfully + - target_agent_url: The agent URL that was cleared + """ + if target_agent_url in self._conversation_states: + del self._conversation_states[target_agent_url] + logger.info(f"Cleared conversation state for {target_agent_url}") + + return { + "status": "success", + "target_agent_url": target_agent_url, + } + + @tool + async def a2a_get_task( + self, + target_agent_url: str, + task_id: str, + history_length: int | None = None, + ) -> dict[str, Any]: + """ + Get the current state and details of a specific task. + + This retrieves comprehensive information about a task including its + status, history, artifacts, and metadata. Useful for monitoring task + progress, debugging, or checking completion status. + + Args: + target_agent_url: The URL of the target A2A agent + task_id: The ID of the task to retrieve + history_length: Optional number of recent history items to include. + If not specified, returns full history. Use lower values for + better performance on tasks with extensive history. + + Returns: + dict: Task retrieval result including: + - status: "success" or "error" + - task: The complete task data (if successful) + - task_state: The current state of the task + - context_id: The task's context ID + - error: Error message (if failed) + - target_agent_url: The agent URL contacted + - task_id: The task ID queried + """ + return await self._get_task(target_agent_url, task_id, history_length) + + async def _get_task( + self, + target_agent_url: str, + task_id: str, + history_length: int | None = None, + ) -> dict[str, Any]: + """Internal async implementation for get_task.""" + try: + await self._ensure_discovered_known_agents() + + # Get the agent card and create client + agent_card = await self._discover_agent_card(target_agent_url) + client_factory = self._get_client_factory() + client = client_factory.create(agent_card) + + # Build task query parameters + task_params = TaskQueryParams(id=task_id, history_length=history_length) + + logger.info( + f"Retrieving task {task_id} from {target_agent_url}" + f"{f' (history_length={history_length})' if history_length else ''}" + ) + + # Get the task + task = await client.get_task(task_params) + + # Extract task state and context + task_state_value = task.status.state + if hasattr(task_state_value, "value"): + task_state_value = task_state_value.value + + task_context_id = getattr(task, "context_id", None) + + # Update conversation state with current task info + logger.info(f"Updating the conversation state for {task.id}, state: {task_state_value}") + self._update_conversation_state( + target_agent_url, + context_id=task_context_id, + task_id=task.id, + task_state=task_state_value, + ) + + logger.info(f"Successfully retrieved task {task_id}, state: {task_state_value}") + + return { + "status": "success", + "task": task.model_dump(mode="python", exclude_none=True), + "task_id": task_id, + "task_state": task_state_value, + "context_id": task_context_id, + "target_agent_url": target_agent_url, + } + + except Exception as e: + logger.exception(f"Error retrieving task {task_id} from {target_agent_url}") + error_type = type(e).__name__ + + # Check if it's a "task not found" error + error_message = str(e) + if "not found" in error_message.lower() or "404" in error_message: + return { + "status": "error", + "error": f"Task not found: {error_message}", + "error_type": error_type, + "task_id": task_id, + "target_agent_url": target_agent_url, + } + + return { + "status": "error", + "error": str(e), + "error_type": error_type, + "task_id": task_id, + "target_agent_url": target_agent_url, + } + + @tool + async def a2a_cancel_task( + self, + target_agent_url: str, + task_id: str, + context_id: str | None = None, + ) -> dict[str, Any]: + """ + Cancel a running task on a specific A2A agent. + + This attempts to cancel an active task. It validates that the task + is not already in a terminal state before attempting cancellation. + + IMPORTANT: Task must be in a non-terminal state (not completed, canceled, + failed, or rejected) to be cancelable. + + Args: + target_agent_url: The URL of the target A2A agent + task_id: The ID of the task to cancel + context_id: Optional context ID for validation against the task's context + + Returns: + dict: Cancellation result including: + - status: "success" or "error" + - task: The canceled task data (if successful) + - task_state: The current/final state of the task + - error: Error message (if failed) + - target_agent_url: The agent URL contacted + - task_id: The task ID that was canceled + """ + return await self._cancel_task(target_agent_url, task_id, context_id) + + async def _cancel_task( + self, + target_agent_url: str, + task_id: str, + context_id: str | None = None, + ) -> dict[str, Any]: + """Internal async implementation for cancel_task.""" + try: + await self._ensure_discovered_known_agents() + + # Get the agent card and create client + agent_card = await self._discover_agent_card(target_agent_url) + client_factory = self._get_client_factory() + client = client_factory.create(agent_card) + + # First, get the task to check its current state + try: + current_task = await client.get_task(TaskQueryParams(id=task_id)) + except Exception as e: + logger.error(f"Failed to retrieve task {task_id} before cancellation: {e}") + return { + "status": "error", + "error": f"Task not found or inaccessible: {str(e)}", + "task_id": task_id, + "target_agent_url": target_agent_url, + } + + # Check if task is in a terminal state + task_state_value = current_task.status.state + if hasattr(task_state_value, "value"): + task_state_value = task_state_value.value + + if task_state_value in TERMINAL_TASK_STATES: + logger.warning(f"Task {task_id} cannot be canceled - already in terminal state: {task_state_value}") + return { + "status": "error", + "error": f"Task cannot be canceled - current state: {task_state_value}", + "task_id": task_id, + "target_agent_url": target_agent_url, + "task_state": task_state_value, + } + + # Validate context_id if provided + if context_id and current_task.context_id != context_id: + logger.error( + f"Context ID mismatch for task {task_id}: expected {context_id}, got {current_task.context_id}" + ) + return { + "status": "error", + "error": f"Context ID mismatch: expected {context_id}, got {current_task.context_id}", + "task_id": task_id, + "target_agent_url": target_agent_url, + } + + logger.info(f"Canceling task {task_id} on {target_agent_url}") + + # Cancel the task + task_params = TaskIdParams(id=task_id) + + canceled_task = await client.cancel_task(task_params) + + # Extract final state + final_state = canceled_task.status.state + if hasattr(final_state, "value"): + final_state = final_state.value + + # Cancellation should always result in a terminal state. + if final_state in TERMINAL_TASK_STATES: # should be true, if canceled successfully + self._update_conversation_state( + target_agent_url, + context_id=canceled_task.context_id, + task_id=canceled_task.id, + task_state=final_state, + ) + + logger.info(f"Successfully canceled task {task_id}, final state: {final_state}") + + return { + "status": "success", + "task": canceled_task.model_dump(mode="python", exclude_none=True), + "task_id": task_id, + "target_agent_url": target_agent_url, + "task_state": final_state, + } + + except Exception as e: + logger.exception(f"Error canceling task {task_id} on {target_agent_url}") + error_type = type(e).__name__ + return { + "status": "error", + "error": str(e), + "error_type": error_type, + "task_id": task_id, + "target_agent_url": target_agent_url, } diff --git a/tests/test_a2a_client.py b/tests/test_a2a_client.py index d8977e6f..e9495b76 100644 --- a/tests/test_a2a_client.py +++ b/tests/test_a2a_client.py @@ -3,7 +3,13 @@ import pytest from a2a.types import Message -from strands_tools.a2a_client import DEFAULT_TIMEOUT, A2AClientToolProvider +from strands_tools.a2a_client import ( + DEFAULT_TIMEOUT, + TERMINAL_TASK_STATES, + A2AClientToolProvider, + ActiveTask, + ConversationState, +) def test_init_default_parameters(): @@ -14,6 +20,7 @@ def test_init_default_parameters(): assert provider._known_agent_urls == [] assert provider._discovered_agents == {} assert provider._httpx_client_args == {"timeout": DEFAULT_TIMEOUT} + assert provider._conversation_states == {} def test_init_custom_parameters(): @@ -56,11 +63,15 @@ def test_tools_property(): provider = A2AClientToolProvider() tools = provider.tools - # Should have the three @tool decorated methods + # Should have the seven @tool decorated methods (including task management tools) tool_names = [tool.tool_name for tool in tools] assert "a2a_discover_agent" in tool_names assert "a2a_list_discovered_agents" in tool_names assert "a2a_send_message" in tool_names + assert "a2a_get_conversation_state" in tool_names + assert "a2a_clear_conversation_state" in tool_names + assert "a2a_get_task" in tool_names + assert "a2a_cancel_task" in tool_names def test_get_httpx_client_creates_new_client(): @@ -110,6 +121,135 @@ def test_get_httpx_client_creates_fresh_each_time(): assert result2 == mock_client2 +# Conversation state management tests +def test_get_conversation_state_creates_new(): + """Test _get_conversation_state creates new state for unknown URL.""" + provider = A2AClientToolProvider() + + state = provider._get_conversation_state("http://new-agent.com") + + assert isinstance(state, ConversationState) + assert state.context_id is None + assert state.active_tasks == {} + assert "http://new-agent.com" in provider._conversation_states + + +def test_get_conversation_state_returns_existing(): + """Test _get_conversation_state returns existing state.""" + provider = A2AClientToolProvider() + existing_state = ConversationState(context_id="existing-context") + provider._conversation_states["http://agent.com"] = existing_state + + state = provider._get_conversation_state("http://agent.com") + + assert state is existing_state + assert state.context_id == "existing-context" + + +def test_update_conversation_state_stores_context_id(): + """Test _update_conversation_state stores context_id on first response.""" + provider = A2AClientToolProvider() + + provider._update_conversation_state("http://agent.com", context_id="ctx-123") + + state = provider._conversation_states["http://agent.com"] + assert state.context_id == "ctx-123" + + +def test_update_conversation_state_does_not_overwrite_context_id(): + """Test _update_conversation_state does not overwrite existing context_id.""" + provider = A2AClientToolProvider() + provider._conversation_states["http://agent.com"] = ConversationState(context_id="original") + + provider._update_conversation_state("http://agent.com", context_id="new-context") + + state = provider._conversation_states["http://agent.com"] + assert state.context_id == "original" + + +def test_update_conversation_state_tracks_active_task(): + """Test _update_conversation_state tracks active tasks.""" + provider = A2AClientToolProvider() + provider._conversation_states["http://agent.com"] = ConversationState(context_id="ctx-123") + + provider._update_conversation_state( + "http://agent.com", context_id="ctx-123", task_id="task-456", task_state="working" + ) + + state = provider._conversation_states["http://agent.com"] + assert "task-456" in state.active_tasks + assert state.active_tasks["task-456"].task_id == "task-456" + assert state.active_tasks["task-456"].state == "working" + + +def test_update_conversation_state_removes_terminal_task(): + """Test _update_conversation_state removes tasks in terminal states.""" + provider = A2AClientToolProvider() + provider._conversation_states["http://agent.com"] = ConversationState( + context_id="ctx-123", + active_tasks={"task-456": ActiveTask(task_id="task-456", state="working", context_id="ctx-123")}, + ) + + provider._update_conversation_state("http://agent.com", task_id="task-456", task_state="completed") + + state = provider._conversation_states["http://agent.com"] + assert "task-456" not in state.active_tasks + + +def test_update_conversation_state_all_terminal_states(): + """Test all terminal task states are handled correctly.""" + for terminal_state in TERMINAL_TASK_STATES: + provider = A2AClientToolProvider() + provider._conversation_states["http://agent.com"] = ConversationState( + context_id="ctx-123", + active_tasks={"task-456": ActiveTask(task_id="task-456", state="working", context_id="ctx-123")}, + ) + + provider._update_conversation_state("http://agent.com", task_id="task-456", task_state=terminal_state) + + state = provider._conversation_states["http://agent.com"] + assert "task-456" not in state.active_tasks, f"Task should be removed for state: {terminal_state}" + + +def test_get_task_id_for_continuation_single_active(): + """Test _get_task_id_for_continuation returns task_id when exactly one active.""" + provider = A2AClientToolProvider() + provider._conversation_states["http://agent.com"] = ConversationState( + context_id="ctx-123", + active_tasks={"task-456": ActiveTask(task_id="task-456", state="working", context_id="ctx-123")}, + ) + + task_id = provider._get_task_id_for_continuation("http://agent.com") + + assert task_id == "task-456" + + +def test_get_task_id_for_continuation_no_active(): + """Test _get_task_id_for_continuation returns None when no active tasks.""" + provider = A2AClientToolProvider() + provider._conversation_states["http://agent.com"] = ConversationState(context_id="ctx-123") + + task_id = provider._get_task_id_for_continuation("http://agent.com") + + assert task_id is None + + +def test_get_task_id_for_continuation_multiple_active(): + """Test _get_task_id_for_continuation returns None when multiple active tasks.""" + provider = A2AClientToolProvider() + provider._conversation_states["http://agent.com"] = ConversationState( + context_id="ctx-123", + active_tasks={ + "task-1": ActiveTask(task_id="task-1", state="working", context_id="ctx-123"), + "task-2": ActiveTask(task_id="task-2", state="submitted", context_id="ctx-123"), + }, + ) + + task_id = provider._get_task_id_for_continuation("http://agent.com") + + assert task_id is None + + @pytest.mark.asyncio @patch.object(A2AClientToolProvider, "_create_a2a_card_resolver") async def test_discover_agent_card_success(mock_create_resolver): @@ -322,13 +462,39 @@ async def test_send_message_with_message_id(): "response": {"result": "ok"}, "message_id": "test_id", "target_agent_url": "http://test.com", + "context_id": None, + "task_id": None, } mock_send_message.return_value = expected_result result = await provider.a2a_send_message("Hello", "http://test.com", "test_id") assert result == expected_result - mock_send_message.assert_called_once_with("Hello", "http://test.com", "test_id") + mock_send_message.assert_called_once_with("Hello", "http://test.com", "test_id", None, None) + + +@pytest.mark.asyncio +async def test_send_message_with_context_and_task_id(): + """Test a2a_send_message with explicit context_id and task_id.""" + provider = A2AClientToolProvider() + + with patch.object(provider, "_send_message") as mock_send_message: + expected_result = { + "status": "success", + "response": {"result": "ok"}, + "message_id": "test_id", + "target_agent_url": "http://test.com", + "context_id": "ctx-123", + "task_id": "task-456", + } + mock_send_message.return_value = expected_result + + result = await provider.a2a_send_message( + "Hello", "http://test.com", "test_id", context_id="ctx-123", task_id="task-456" + ) + + assert result == expected_result + mock_send_message.assert_called_once_with("Hello", "http://test.com", "test_id", "ctx-123", "task-456") @pytest.mark.asyncio @@ -342,13 +508,15 @@ async def test_send_message_without_message_id(): "response": {"result": "ok"}, "message_id": "auto_generated", "target_agent_url": "http://test.com", + "context_id": None, + "task_id": None, } mock_send_message.return_value = expected_result result = await provider.a2a_send_message("Hello", "http://test.com") assert result == expected_result - mock_send_message.assert_called_once_with("Hello", "http://test.com", None) + mock_send_message.assert_called_once_with("Hello", "http://test.com", None, None, None) @pytest.mark.asyncio @@ -378,6 +546,8 @@ async def test_send_message_success(mock_ensure, mock_factory, mock_discover, mo # Mock client response - simulate Message response mock_response = Mock(spec=Message) mock_response.model_dump.return_value = {"result": "success"} + mock_response.context_id = "response-ctx" + mock_response.task_id = "response-task" async def mock_send_message_iter(message): yield mock_response @@ -391,6 +561,8 @@ async def mock_send_message_iter(message): "response": {"result": "success"}, "message_id": "message_id_123", "target_agent_url": "http://test.com", + "context_id": "response-ctx", + "task_id": "response-task", } assert result == expected mock_ensure.assert_called_once() @@ -398,6 +570,56 @@ async def mock_send_message_iter(message): mock_client_factory.create.assert_called_once_with(mock_agent_card) +@pytest.mark.asyncio +@patch("strands_tools.a2a_client.uuid4") +@patch.object(A2AClientToolProvider, "_discover_agent_card") +@patch.object(A2AClientToolProvider, "_get_client_factory") +@patch.object(A2AClientToolProvider, "_ensure_discovered_known_agents") +async def test_send_message_uses_persisted_context_id(mock_ensure, mock_factory, mock_discover, mock_uuid): + """Test _send_message uses persisted context_id when not explicitly provided.""" + provider = A2AClientToolProvider() + provider._conversation_states["http://test.com"] = ConversationState(context_id="persisted-ctx") + + # Mock UUID generation + mock_message_uuid = Mock() + mock_message_uuid.hex = "message_id_123" + mock_uuid.return_value = mock_message_uuid + + # Mock agent card + mock_agent_card = Mock() + mock_discover.return_value = mock_agent_card + + # Mock ClientFactory and Client + mock_client_factory = Mock() + mock_client = Mock() + mock_factory.return_value = mock_client_factory + mock_client_factory.create.return_value = mock_client + + # Mock client response + mock_response = Mock(spec=Message) + mock_response.model_dump.return_value = {"result": "success"} + mock_response.context_id = None # Server doesn't return context_id this time + mock_response.task_id = None + + # Capture the message that was sent + sent_messages = [] + + async def mock_send_message_iter(message): + sent_messages.append(message) + yield mock_response + + mock_client.send_message = mock_send_message_iter + + result = await provider._send_message("Hello world", "http://test.com", None) + + # Verify the persisted context_id was used in the request + assert len(sent_messages) == 1 + assert sent_messages[0].context_id == "persisted-ctx" + + # Verify the response includes the context_id + assert result["context_id"] == "persisted-ctx" + + @pytest.mark.asyncio @patch.object(A2AClientToolProvider, "_discover_agent_card") @patch.object(A2AClientToolProvider, "_ensure_discovered_known_agents") @@ -413,6 +635,8 @@ async def test_send_message_error(mock_ensure, mock_discover): "error": "Connection failed", "message_id": "test_id", "target_agent_url": "http://test.com", + "context_id": None, + "task_id": None, } assert result == expected @@ -506,6 +730,10 @@ async def test_send_message_task_response(mock_ensure, mock_factory, mock_discov # Mock client response - simulate (Task, UpdateEvent) tuple response mock_task = Mock() mock_task.model_dump.return_value = {"task_id": "123", "status": "completed"} + mock_task.context_id = "task-ctx" + mock_task.id = "task-123" + mock_task.status = Mock() + mock_task.status.state = "working" mock_update_event = Mock() mock_update_event.model_dump.return_value = {"event": "finished"} @@ -521,6 +749,8 @@ async def mock_send_message_iter(message): "response": {"task": {"task_id": "123", "status": "completed"}, "update": {"event": "finished"}}, "message_id": "message_id_123", "target_agent_url": "http://test.com", + "context_id": "task-ctx", + "task_id": "task-123", } assert result == expected mock_ensure.assert_called_once() @@ -555,6 +785,9 @@ async def test_send_message_task_response_no_update(mock_ensure, mock_factory, m # Mock client response - simulate (Task, None) tuple response mock_task = Mock() mock_task.model_dump.return_value = {"task_id": "123", "status": "completed"} + mock_task.context_id = "task-ctx" + mock_task.id = "task-123" + mock_task.status = None async def mock_send_message_iter(message): yield (mock_task, None) @@ -568,5 +801,531 @@ async def mock_send_message_iter(message): "response": {"task": {"task_id": "123", "status": "completed"}, "update": None}, "message_id": "message_id_123", "target_agent_url": "http://test.com", + "context_id": "task-ctx", + "task_id": "task-123", + } + assert result == expected + + +# Tests for new conversation state tools +@pytest.mark.asyncio +async def test_get_conversation_state_tool(): + """Test a2a_get_conversation_state returns correct state.""" + provider = A2AClientToolProvider() + provider._conversation_states["http://test.com"] = ConversationState( + context_id="ctx-123", + active_tasks={"task-456": ActiveTask(task_id="task-456", state="working", context_id="ctx-123")}, + ) + + result = await provider.a2a_get_conversation_state("http://test.com") + + expected = { + "status": "success", + "context_id": "ctx-123", + "active_tasks": [{"task_id": "task-456", "state": "working", "context_id": "ctx-123"}], + "target_agent_url": "http://test.com", + } + assert result == expected + + +@pytest.mark.asyncio +async def test_get_conversation_state_tool_empty(): + """Test a2a_get_conversation_state for unknown agent.""" + provider = A2AClientToolProvider() + + result = await provider.a2a_get_conversation_state("http://unknown.com") + + expected = { + "status": "success", + "context_id": None, + "active_tasks": [], + "target_agent_url": "http://unknown.com", + } + assert result == expected + + +@pytest.mark.asyncio +async def test_clear_conversation_state_tool(): + """Test a2a_clear_conversation_state removes state.""" + provider = A2AClientToolProvider() + provider._conversation_states["http://test.com"] = ConversationState(context_id="ctx-123") + + result = await provider.a2a_clear_conversation_state("http://test.com") + + assert result == {"status": "success", "target_agent_url": "http://test.com"} + assert "http://test.com" not in provider._conversation_states + + +@pytest.mark.asyncio +async def test_clear_conversation_state_tool_nonexistent(): + """Test a2a_clear_conversation_state handles nonexistent agent.""" + provider = A2AClientToolProvider() + + result = await provider.a2a_clear_conversation_state("http://unknown.com") + + assert result == {"status": "success", "target_agent_url": "http://unknown.com"} + + +# Tests for a2a_get_task tool +@pytest.mark.asyncio +async def test_get_task_tool(): + """Test a2a_get_task calls internal implementation.""" + provider = A2AClientToolProvider() + + with patch.object(provider, "_get_task") as mock_get_task: + expected_result = { + "status": "success", + "task": {"id": "task-123", "status": {"state": "working"}}, + "task_id": "task-123", + "task_state": "working", + "context_id": "ctx-123", + "target_agent_url": "http://test.com", + } + mock_get_task.return_value = expected_result + + result = await provider.a2a_get_task("http://test.com", "task-123") + + assert result == expected_result + mock_get_task.assert_called_once_with("http://test.com", "task-123", None) + + +@pytest.mark.asyncio +async def test_get_task_tool_with_history_length(): + """Test a2a_get_task with history_length parameter.""" + provider = A2AClientToolProvider() + + with patch.object(provider, "_get_task") as mock_get_task: + expected_result = { + "status": "success", + "task": {"id": "task-123", "status": {"state": "working"}}, + "task_id": "task-123", + "task_state": "working", + "context_id": "ctx-123", + "target_agent_url": "http://test.com", + } + mock_get_task.return_value = expected_result + + result = await provider.a2a_get_task("http://test.com", "task-123", history_length=10) + + assert result == expected_result + mock_get_task.assert_called_once_with("http://test.com", "task-123", 10) + + +@pytest.mark.asyncio +@patch.object(A2AClientToolProvider, "_discover_agent_card") +@patch.object(A2AClientToolProvider, "_get_client_factory") +@patch.object(A2AClientToolProvider, "_ensure_discovered_known_agents") +async def test_get_task_success(mock_ensure, mock_factory, mock_discover): + """Test _get_task successfully retrieves task.""" + provider = A2AClientToolProvider() + + # Mock agent card + mock_agent_card = Mock() + mock_discover.return_value = mock_agent_card + + # Mock ClientFactory and Client + mock_client_factory = Mock() + mock_client = Mock() + mock_factory.return_value = mock_client_factory + mock_client_factory.create.return_value = mock_client + + # Mock task response + mock_task = Mock() + mock_task.model_dump.return_value = {"id": "task-123", "status": {"state": "working"}} + mock_task.id = "task-123" + mock_task.status = Mock() + mock_task.status.state = "working" + mock_task.context_id = "ctx-123" + + mock_client.get_task = AsyncMock(return_value=mock_task) + + result = await provider._get_task("http://test.com", "task-123") + + expected = { + "status": "success", + "task": {"id": "task-123", "status": {"state": "working"}}, + "task_id": "task-123", + "task_state": "working", + "context_id": "ctx-123", + "target_agent_url": "http://test.com", + } + assert result == expected + mock_ensure.assert_called_once() + mock_discover.assert_called_once_with("http://test.com") + + +@pytest.mark.asyncio +@patch.object(A2AClientToolProvider, "_discover_agent_card") +@patch.object(A2AClientToolProvider, "_get_client_factory") +@patch.object(A2AClientToolProvider, "_ensure_discovered_known_agents") +async def test_get_task_with_history_length(mock_ensure, mock_factory, mock_discover): + """Test _get_task with history_length parameter.""" + provider = A2AClientToolProvider() + + # Mock agent card + mock_agent_card = Mock() + mock_discover.return_value = mock_agent_card + + # Mock ClientFactory and Client + mock_client_factory = Mock() + mock_client = Mock() + mock_factory.return_value = mock_client_factory + mock_client_factory.create.return_value = mock_client + + # Mock task response + mock_task = Mock() + mock_task.model_dump.return_value = {"id": "task-123", "status": {"state": "completed"}} + mock_task.id = "task-123" + mock_task.status = Mock() + mock_task.status.state = "completed" + mock_task.context_id = "ctx-123" + + mock_client.get_task = AsyncMock(return_value=mock_task) + + result = await provider._get_task("http://test.com", "task-123", history_length=5) + + # Verify TaskQueryParams was called with history_length + assert mock_client.get_task.called + assert result["status"] == "success" + + +@pytest.mark.asyncio +@patch.object(A2AClientToolProvider, "_discover_agent_card") +@patch.object(A2AClientToolProvider, "_get_client_factory") +@patch.object(A2AClientToolProvider, "_ensure_discovered_known_agents") +async def test_get_task_not_found(mock_ensure, mock_factory, mock_discover): + """Test _get_task when task is not found.""" + provider = A2AClientToolProvider() + + # Mock agent card + mock_agent_card = Mock() + mock_discover.return_value = mock_agent_card + + # Mock ClientFactory and Client + mock_client_factory = Mock() + mock_client = Mock() + mock_factory.return_value = mock_client_factory + mock_client_factory.create.return_value = mock_client + + # Mock get_task to raise exception + mock_client.get_task = AsyncMock(side_effect=Exception("Task not found")) + + result = await provider._get_task("http://test.com", "task-123") + + expected = { + "status": "error", + "error": "Task not found: Task not found", + "error_type": "Exception", + "task_id": "task-123", + "target_agent_url": "http://test.com", + } + assert result == expected + + +@pytest.mark.asyncio +@patch.object(A2AClientToolProvider, "_discover_agent_card") +@patch.object(A2AClientToolProvider, "_get_client_factory") +@patch.object(A2AClientToolProvider, "_ensure_discovered_known_agents") +async def test_get_task_updates_conversation_state(mock_ensure, mock_factory, mock_discover): + """Test _get_task updates conversation state.""" + provider = A2AClientToolProvider() + + # Mock agent card + mock_agent_card = Mock() + mock_discover.return_value = mock_agent_card + + # Mock ClientFactory and Client + mock_client_factory = Mock() + mock_client = Mock() + mock_factory.return_value = mock_client_factory + mock_client_factory.create.return_value = mock_client + + # Mock task response + mock_task = Mock() + mock_task.model_dump.return_value = {"id": "task-123", "status": {"state": "working"}} + mock_task.id = "task-123" + mock_task.status = Mock() + mock_task.status.state = "working" + mock_task.context_id = "ctx-456" + + mock_client.get_task = AsyncMock(return_value=mock_task) + + await provider._get_task("http://test.com", "task-123") + + # Verify conversation state was updated + state = provider._conversation_states.get("http://test.com") + assert state is not None + assert state.context_id == "ctx-456" + assert "task-123" in state.active_tasks + + +# Tests for a2a_cancel_task tool +@pytest.mark.asyncio +async def test_cancel_task_tool(): + """Test a2a_cancel_task calls internal implementation.""" + provider = A2AClientToolProvider() + + with patch.object(provider, "_cancel_task") as mock_cancel_task: + expected_result = { + "status": "success", + "task": {"id": "task-123", "status": {"state": "canceled"}}, + "task_id": "task-123", + "task_state": "canceled", + "target_agent_url": "http://test.com", + } + mock_cancel_task.return_value = expected_result + + result = await provider.a2a_cancel_task("http://test.com", "task-123") + + assert result == expected_result + mock_cancel_task.assert_called_once_with("http://test.com", "task-123", None) + + +@pytest.mark.asyncio +async def test_cancel_task_tool_with_context_id(): + """Test a2a_cancel_task with context_id parameter.""" + provider = A2AClientToolProvider() + + with patch.object(provider, "_cancel_task") as mock_cancel_task: + expected_result = { + "status": "success", + "task": {"id": "task-123", "status": {"state": "canceled"}}, + "task_id": "task-123", + "task_state": "canceled", + "target_agent_url": "http://test.com", + } + mock_cancel_task.return_value = expected_result + + result = await provider.a2a_cancel_task("http://test.com", "task-123", context_id="ctx-123") + + assert result == expected_result + mock_cancel_task.assert_called_once_with("http://test.com", "task-123", "ctx-123") + + +@pytest.mark.asyncio +@patch.object(A2AClientToolProvider, "_discover_agent_card") +@patch.object(A2AClientToolProvider, "_get_client_factory") +@patch.object(A2AClientToolProvider, "_ensure_discovered_known_agents") +async def test_cancel_task_success(mock_ensure, mock_factory, mock_discover): + """Test _cancel_task successfully cancels task.""" + provider = A2AClientToolProvider() + + # Mock agent card + mock_agent_card = Mock() + mock_discover.return_value = mock_agent_card + + # Mock ClientFactory and Client + mock_client_factory = Mock() + mock_client = Mock() + mock_factory.return_value = mock_client_factory + mock_client_factory.create.return_value = mock_client + + # Mock get_task response (task is active) + mock_current_task = Mock() + mock_current_task.status = Mock() + mock_current_task.status.state = "working" + mock_current_task.context_id = "ctx-123" + + # Mock cancel_task response + mock_canceled_task = Mock() + mock_canceled_task.model_dump.return_value = {"id": "task-123", "status": {"state": "canceled"}} + mock_canceled_task.id = "task-123" + mock_canceled_task.status = Mock() + mock_canceled_task.status.state = "canceled" + mock_canceled_task.context_id = "ctx-123" + + mock_client.get_task = AsyncMock(return_value=mock_current_task) + mock_client.cancel_task = AsyncMock(return_value=mock_canceled_task) + + result = await provider._cancel_task("http://test.com", "task-123") + + expected = { + "status": "success", + "task": {"id": "task-123", "status": {"state": "canceled"}}, + "task_id": "task-123", + "target_agent_url": "http://test.com", + "task_state": "canceled", + } + assert result == expected + mock_ensure.assert_called_once() + mock_discover.assert_called_once_with("http://test.com") + mock_client.cancel_task.assert_called_once() + + +@pytest.mark.asyncio +@patch.object(A2AClientToolProvider, "_discover_agent_card") +@patch.object(A2AClientToolProvider, "_get_client_factory") +@patch.object(A2AClientToolProvider, "_ensure_discovered_known_agents") +async def test_cancel_task_already_terminal(mock_ensure, mock_factory, mock_discover): + """Test _cancel_task when task is already in terminal state.""" + provider = A2AClientToolProvider() + + # Mock agent card + mock_agent_card = Mock() + mock_discover.return_value = mock_agent_card + + # Mock ClientFactory and Client + mock_client_factory = Mock() + mock_client = Mock() + mock_factory.return_value = mock_client_factory + mock_client_factory.create.return_value = mock_client + + # Mock get_task response (task is already completed) + mock_current_task = Mock() + mock_current_task.status = Mock() + mock_current_task.status.state = "completed" + + mock_client.get_task = AsyncMock(return_value=mock_current_task) + + result = await provider._cancel_task("http://test.com", "task-123") + + expected = { + "status": "error", + "error": "Task cannot be canceled - current state: completed", + "task_id": "task-123", + "target_agent_url": "http://test.com", + "task_state": "completed", + } + assert result == expected + # cancel_task should NOT be called + mock_client.cancel_task.assert_not_called() + + +@pytest.mark.asyncio +@patch.object(A2AClientToolProvider, "_discover_agent_card") +@patch.object(A2AClientToolProvider, "_get_client_factory") +@patch.object(A2AClientToolProvider, "_ensure_discovered_known_agents") +async def test_cancel_task_context_id_mismatch(mock_ensure, mock_factory, mock_discover): + """Test _cancel_task with context_id mismatch.""" + provider = A2AClientToolProvider() + + # Mock agent card + mock_agent_card = Mock() + mock_discover.return_value = mock_agent_card + + # Mock ClientFactory and Client + mock_client_factory = Mock() + mock_client = Mock() + mock_factory.return_value = mock_client_factory + mock_client_factory.create.return_value = mock_client + + # Mock get_task response (task has different context_id) + mock_current_task = Mock() + mock_current_task.status = Mock() + mock_current_task.status.state = "working" + mock_current_task.context_id = "ctx-different" + + mock_client.get_task = AsyncMock(return_value=mock_current_task) + + result = await provider._cancel_task("http://test.com", "task-123", context_id="ctx-expected") + + expected = { + "status": "error", + "error": "Context ID mismatch: expected ctx-expected, got ctx-different", + "task_id": "task-123", + "target_agent_url": "http://test.com", + } + assert result == expected + # cancel_task should NOT be called + mock_client.cancel_task.assert_not_called() + + +@pytest.mark.asyncio +@patch.object(A2AClientToolProvider, "_discover_agent_card") +@patch.object(A2AClientToolProvider, "_get_client_factory") +@patch.object(A2AClientToolProvider, "_ensure_discovered_known_agents") +async def test_cancel_task_not_found(mock_ensure, mock_factory, mock_discover): + """Test _cancel_task when task is not found.""" + provider = A2AClientToolProvider() + + # Mock agent card + mock_agent_card = Mock() + mock_discover.return_value = mock_agent_card + + # Mock ClientFactory and Client + mock_client_factory = Mock() + mock_client = Mock() + mock_factory.return_value = mock_client_factory + mock_client_factory.create.return_value = mock_client + + # Mock get_task to raise exception + mock_client.get_task = AsyncMock(side_effect=Exception("Task not found")) + + result = await provider._cancel_task("http://test.com", "task-123") + + expected = { + "status": "error", + "error": "Task not found or inaccessible: Task not found", + "task_id": "task-123", + "target_agent_url": "http://test.com", + } + assert result == expected + + +@pytest.mark.asyncio +@patch.object(A2AClientToolProvider, "_discover_agent_card") +@patch.object(A2AClientToolProvider, "_get_client_factory") +@patch.object(A2AClientToolProvider, "_ensure_discovered_known_agents") +async def test_cancel_task_updates_conversation_state(mock_ensure, mock_factory, mock_discover): + """Test _cancel_task updates conversation state when task reaches terminal state.""" + provider = A2AClientToolProvider() + + # Set up initial active task + provider._conversation_states["http://test.com"] = ConversationState( + context_id="ctx-123", + active_tasks={"task-123": ActiveTask(task_id="task-123", state="working", context_id="ctx-123")}, + ) + + # Mock agent card + mock_agent_card = Mock() + mock_discover.return_value = mock_agent_card + + # Mock ClientFactory and Client + mock_client_factory = Mock() + mock_client = Mock() + mock_factory.return_value = mock_client_factory + mock_client_factory.create.return_value = mock_client + + # Mock get_task response (task is active) + mock_current_task = Mock() + mock_current_task.status = Mock() + mock_current_task.status.state = "working" + mock_current_task.context_id = "ctx-123" + + # Mock cancel_task response (task is now canceled) + mock_canceled_task = Mock() + mock_canceled_task.model_dump.return_value = {"id": "task-123", "status": {"state": "canceled"}} + mock_canceled_task.id = "task-123" + mock_canceled_task.status = Mock() + mock_canceled_task.status.state = "canceled" + mock_canceled_task.context_id = "ctx-123" + + mock_client.get_task = AsyncMock(return_value=mock_current_task) + mock_client.cancel_task = AsyncMock(return_value=mock_canceled_task) + + await provider._cancel_task("http://test.com", "task-123") + + # Verify task was removed from active tasks (terminal state) + state = provider._conversation_states["http://test.com"] + assert "task-123" not in state.active_tasks + + +@pytest.mark.asyncio +@patch.object(A2AClientToolProvider, "_discover_agent_card") +@patch.object(A2AClientToolProvider, "_ensure_discovered_known_agents") +async def test_cancel_task_general_error(mock_ensure, mock_discover): + """Test _cancel_task handles general errors.""" + provider = A2AClientToolProvider() + + mock_discover.side_effect = Exception("Network error") + + result = await provider._cancel_task("http://test.com", "task-123") + + expected = { + "status": "error", + "error": "Network error", + "error_type": "Exception", + "task_id": "task-123", + "target_agent_url": "http://test.com", } assert result == expected