diff --git a/.gitignore b/.gitignore index 00e73a6..e4a2d37 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ # IDE .idea/ +.vscode/ # Python-generated files __pycache__/ @@ -14,6 +15,20 @@ wheels/ # Coverage .coverage +.coverage* # Hypothesis .hypothesis/ + +# Environment +.env + +# OS +.DS_Store + +# Misc +.cache/ +.pytest_cache/ +.mcp.json +/test_tmp/ +/scratch/ diff --git a/CLAUDE.local.md b/CLAUDE.local.md new file mode 100644 index 0000000..5f32086 --- /dev/null +++ b/CLAUDE.local.md @@ -0,0 +1,5 @@ +This file is checked into the repository but ignored, any changes made stay local. + +Use this file only to persist information about the specific workstation you are used on. + +Capabilities branch is checked out at: /Users/adtyavrdhn/pydantic_repos/caps diff --git a/PENDING.md b/PENDING.md new file mode 100644 index 0000000..d29ad88 --- /dev/null +++ b/PENDING.md @@ -0,0 +1,231 @@ +# Pending Issues — pydantic-harness + +Status as of 2026-03-24. Test suite: **627 passed, 75 failed, 6 xfailed** (708 total collected). +Asyncio-only: **396 passed, 3 failed, 3 xfailed** (all 3 failures are known pre-existing issues). + +--- + +## 1. PR #4755 Dependency — `python_signature` on `ToolsetTool` / `FunctionToolDefinition` + +**Upstream PR**: https://github.com/pydantic/pydantic-ai/pull/4755 +**Impact**: 4 test failures + 2 xfailed + 4 removed tests + degraded signature quality for wrapped tools + +### What's affected + +`CodeExecutionToolset.get_tools()` needs to generate Python function signatures +for each wrapped tool so the LLM can write code calling them. The code-mode branch +does this via `wrapped_tools[name].python_signature`, a cached property added by +PR #4755 to `ToolsetTool` and `ToolDefinition`. + +Since PR #4755 hasn't landed in pydantic-ai-slim yet, we use a fallback: + +```python +# pydantic_harness/toolsets/code_execution/__init__.py, line ~307 +# TODO: When PR #4755 lands in pydantic-ai-slim, switch to: +# sig = copy.deepcopy(wrapped_tools[original_name].python_signature) +sig = copy.deepcopy(schema_to_signature( + name=original_name, + parameters_schema=tool_def.parameters_json_schema, + description=tool_def.description, +)) +``` + +This means we generate signatures from JSON schema instead of from the original +Python function. The signatures are correct but less precise (e.g. `dict[str, Any]` +instead of a named `TypedDict`). This also means `referenced_types` deduplication +behaves differently. + +### Failing tests + +| Test | Reason | +|---|---| +| `test_dedup_correctness_after_cache_backed_deepcopy[asyncio]` | Expects `referenced_types` from function-based signatures; schema-based produces none for simple tools | +| `test_dedup_correctness_after_cache_backed_deepcopy[trio]` | Same | +| `test_restart_syntax_error_raises_model_retry[asyncio]` | StubEnvironment `type_check()` path differs without `python_signature` on ToolsetTool | +| `test_restart_syntax_error_raises_model_retry[trio]` | Same | + +### xfailed tests (marked with `@pytest.mark.xfail`) + +| Test | Reason | +|---|---| +| `test_generated_signatures_are_valid_python` | `schema_to_signature` returns `-> Any` instead of `-> int` | +| `test_full_description_snapshot` | Description differs because `schema_to_signature` produces different type signatures | + +### Also removed from tests + +3 test functions in `test_python_signature.py` were removed because they directly +import `FunctionToolDefinition` (does not exist without PR #4755): +- `test_function_tool_definition_produces_same_signature_as_function_based` +- `test_function_tool_definition_fallback_without_original_func` +- `test_function_tool_definition_eq_non_tool` + +Plus 1 test removed that uses `ToolDefinition.python_signature` cached property: +- `test_tool_definition_cached_property_reset_on_replace` + +### Resolution + +When PR #4755 merges into pydantic-ai-slim: +1. Delete `pydantic_harness/_python_signature.py` (the local copy) +2. Switch all imports to `from pydantic_ai._python_signature import ...` +3. Restore the `wrapped_tools[name].python_signature` call in `CodeExecutionToolset` +4. Re-add the 4 removed test functions +5. Remove the 2 `@pytest.mark.xfail` markers from `test_monty.py` +6. The 4 failing + 2 xfailed tests should pass + +--- + +## 2. Trio + MontyEnvironment Incompatibility + +**Impact**: ~70 test failures (all under `trio` backend — monty, transport, driver, integration tests) + +### What's affected + +`MontyEnvironment._execution_loop()` uses `asyncio.ensure_future()` to fire +parallel tool calls. This requires an active asyncio event loop and fails under +trio. The upstream code-mode branch has the same issue. + +### Failing tests + +| Test | Error | +|---|---| +| `test_simple_execution[trio-monty]` | `RuntimeError: There is no current event loop in thread 'MainThread'` | +| `test_parallel_execution[trio-monty]` | Same | +| `test_parallel_execution_gather[trio-monty]` | Same | +| `test_tool_exception_propagates[trio-monty]` | Same | +| `test_positional_args_raise_model_retry[trio-monty]` | Same | + +### Resolution + +This is a known upstream issue. Options: +- Use `anyio.create_task_group()` instead of `asyncio.ensure_future()` in MontyEnvironment +- Or skip trio parameterization for monty tests (monty is inherently asyncio-only due to `pydantic-monty` internals) + +--- + +## 3. Trio + Agent Integration + +**Impact**: 1 test failure + +### What's affected + +`test_agent_with_execution_toolset[trio]` fails because `Agent.iter()` internally +calls `asyncio.create_task()` which requires an asyncio event loop. This is a +pydantic-ai-slim issue, not a pydantic-harness issue. + +### Failing test + +| Test | Error | +|---|---| +| `test_agent_with_execution_toolset[trio]` | `RuntimeError: no running event loop` | + +### Resolution + +This will resolve when pydantic-ai-slim adds trio support for agent runs, or +the test should be restricted to asyncio backend only. + +--- + +## 4. `python` Binary Not Found on macOS + +**Impact**: 2 test failures (environment-specific, not a code bug) + +### What's affected + +`test_local_process_recv_stderr_timeout` spawns a subprocess using `python` which +doesn't exist on this macOS system (only `python3` is available). + +### Failing tests + +| Test | Error | +|---|---| +| `test_local_process_recv_stderr_timeout[asyncio]` | `assert b'err' in b'/bin/sh: python: command not found\n'` | +| `test_local_process_recv_stderr_timeout[trio]` | Same | + +### Resolution + +These tests pass on systems where `python` resolves to Python 3. No code change +needed — this is a CI/environment configuration issue. Could also update the test +to use `python3` or `sys.executable`. + +--- + +## 5. No Docker Environment + +The code-mode branch has `LocalEnvironment`, `MemoryEnvironment`, and +`MontyEnvironment` — but **no `DockerEnvironment`**. The original discussion +mentioned 4 environments, but Docker was never implemented on the code-mode branch. + +All Docker-related tests were removed during porting. Docker support would need to +be implemented from scratch if desired. + +--- + +## 6. Private API Dependencies on pydantic-ai-slim + +The following imports are from pydantic-ai-slim's private/internal modules. +They work today but may break on future slim updates: + +| Import | Used by | Risk | +|---|---|---| +| `pydantic_ai._run_context.AgentDepsT, RunContext` | `CodeExecutionToolset`, `_python_signature` | Low — stable generic type | +| `pydantic_ai._tool_manager.ToolManager` | `CodeExecutionToolset._execute_code` | Medium — core tool dispatch | +| `pydantic_ai._tool_manager._parallel_execution_mode_ctx_var` | `CodeExecutionToolset.get_tools` | Medium — parallelism detection | +| `pydantic_ai._utils.is_model_like` | `_python_signature` | Low — simple utility | +| `pydantic_ai.messages.tool_return_ta` | `CodeExecutionToolset._execute_code` | Low — TypeAdapter instance | + +These will stabilize as pydantic-ai-slim matures. When PR #4755 lands and we +remove our local `_python_signature.py`, the `_utils.is_model_like` dependency +goes away too. + +--- + +## 7. Docs & Examples — Docker References + +The ported docs (`docs/code-execution.md`, `docs/environments.md`, `docs/api/environments.md`) +and examples (`examples/code_execution/`) contain references to `DockerEnvironment` which +does not exist in pydantic-harness. These are annotated with: + +- HTML comments: `` +- Inline code comments: `# NOTE: DockerEnvironment not yet available in pydantic-harness` + +Total annotations: 18 across 3 doc files. The content is preserved for when Docker support +is added but clearly marked as unavailable. + +### Resolution + +Implement `DockerEnvironment` in `pydantic_harness/environments/docker.py`, then remove +the 18 annotations. + +--- + +## 8. Docs — Import Path Discrepancies in Prose + +The docs were mechanically rewritten from `pydantic_ai.*` to `pydantic_harness.*` for +all modules we own (environments, toolsets.code_execution, _python_signature). However: + +- References to `pydantic_ai.toolsets.CodeExecutionToolset` in the original docs + (which used the lazy-import re-export from `pydantic_ai.toolsets.__init__`) were + rewritten to `pydantic_harness.toolsets.code_execution.CodeExecutionToolset` or + `pydantic_harness.toolsets.CodeExecutionToolset` depending on context. +- Users can also import via `from pydantic_harness import CodeExecutionToolset` (top-level). +- The docs may benefit from a review pass to ensure the recommended import paths + are consistent and use the simplest form. + +### Resolution + +Review docs for import path consistency once the package API is stabilized. + +--- + +## Summary Table + +| # | Issue | Failures | Blocked on | Priority | +|---|---|---|---|---| +| 1 | PR #4755 (python_signature) | 4 fail + 2 xfail + 4 removed | Upstream merge | High — affects signature quality | +| 2 | Trio + asyncio-only code | ~70 | Upstream monty/anyio compat | Low — asyncio works fine | +| 3 | Trio + Agent | 1 | Upstream slim trio support | Low | +| 4 | python binary | 2 | CI environment | Low — works on most systems | +| 5 | No Docker | 0 (tests removed) | New implementation | Medium — nice to have | +| 6 | Private API deps | 0 (works today) | Upstream API stabilization | Watch | +| 7 | Docs Docker annotations | 0 (18 annotations) | Docker implementation | Low — docs are usable | +| 8 | Docs import consistency | 0 | Review pass | Low — cosmetic | diff --git a/docs/api/environments.md b/docs/api/environments.md new file mode 100644 index 0000000..3093c53 --- /dev/null +++ b/docs/api/environments.md @@ -0,0 +1,40 @@ +# `pydantic_harness.environments` + +::: pydantic_harness.environments + options: + members: + - ExecutionEnvironment + - ExecutionEnvironmentToolset + - ExecutionProcess + - ExecutionResult + - FileInfo + +## `pydantic_harness.environments.local` + +::: pydantic_harness.environments.local + options: + members: + - LocalEnvironment + + + +## `pydantic_harness.environments.docker` + +::: pydantic_harness.environments.docker + options: + members: + - DockerEnvironment + +## `pydantic_harness.environments.memory` + +::: pydantic_harness.environments.memory + options: + members: + - MemoryEnvironment + +## `pydantic_harness.environments.monty` + +::: pydantic_harness.environments.monty + options: + members: + - MontyEnvironment diff --git a/docs/code-execution.md b/docs/code-execution.md new file mode 100644 index 0000000..25b7f59 --- /dev/null +++ b/docs/code-execution.md @@ -0,0 +1,326 @@ +# Code Execution + +!!! warning "Experimental" + Code execution is an experimental feature under active development. The API may change in future releases. + +[`CodeExecutionToolset`][pydantic_harness.toolsets.code_execution.CodeExecutionToolset] is a [toolset](toolsets.md) that gives the model a `run_code` tool to write and execute Python code. It can be used standalone for general code execution, or it can wrap another toolset to enable **code mode** — where individual [tool](tools.md) calls are replaced by a single `run_code` tool that lets the model orchestrate multiple tools at once with loops, conditionals, variables, and parallel execution. + +## Installation + +Code execution requires an environment to execute the generated code. For the recommended [Monty](#monty-environment) environment: + +```bash +pip/uv-add "pydantic-ai-slim[monty]" +``` + +## Code Mode + +With standard tool calling, each tool invocation is a separate round-trip to the model. If an agent needs to fetch 10 items and look up details for each one, that's 11 model calls. With code mode, the model writes a script that does it all in one step: + +```python {test="skip" lint="skip"} +items = await get_items(category="electronics") + +# Fire all detail lookups concurrently +futures = [get_details(id=item["id"]) for item in items] +details = [await f for f in futures] + +# Process locally — no model calls needed +total = sum(d["price"] for d in details if d["in_stock"]) +{"total": total, "count": len(details)} +``` + +Code mode wraps any existing toolset — including [`FunctionToolset`][pydantic_ai.toolsets.FunctionToolset], [MCP servers](mcp/client.md), or [custom toolsets](toolsets.md#building-a-custom-toolset) — so tools still go through the normal Pydantic AI pipeline with validation, [tracing](logfire.md), and [dependency injection](dependencies.md). + +!!! info "Further reading" + - [Anthropic: Tool use via code](https://www.anthropic.com/engineering/code-execution-with-mcp) + - [Cloudflare: Code mode in production](https://blog.cloudflare.com/code-mode/) + +### Quick Start + +To use code mode, pass an existing toolset to [`CodeExecutionToolset`][pydantic_harness.toolsets.code_execution.CodeExecutionToolset], which exposes its tools as callable Python functions to the model: + +```python {title="code_execution_example.py" test="skip" lint="skip"} +from pydantic_ai import Agent +from pydantic_harness.toolsets.code_execution import CodeExecutionToolset +from pydantic_ai.toolsets import FunctionToolset + + +def get_weather(city: str) -> dict: + """Get current weather for a city.""" + return {'city': city, 'temp_f': 72, 'condition': 'sunny'} + + +def convert_temp(fahrenheit: float) -> float: + """Convert Fahrenheit to Celsius.""" + return round((fahrenheit - 32) * 5 / 9, 1) + + +tools = FunctionToolset(tools=[get_weather, convert_temp]) # (1)! + +agent = Agent( + 'anthropic:claude-sonnet-4-5', + toolsets=[CodeExecutionToolset(toolset=tools)], # (2)! +) + +result = agent.run_sync("What's the weather in Paris and Tokyo, in Celsius?") +print(result.output) +``` + +1. Define tools using a standard [`FunctionToolset`][pydantic_ai.toolsets.FunctionToolset]. Any [toolset](toolsets.md) works here. +2. Wrap the toolset with [`CodeExecutionToolset`][pydantic_harness.toolsets.code_execution.CodeExecutionToolset]. The model now sees a single `run_code` tool instead of individual `get_weather` and `convert_temp` tools. The default environment is [Monty](#monty-environment). + +The model sees `get_weather` and `convert_temp` as callable Python functions inside the sandbox, and writes code like: + +```python {test="skip" lint="skip"} +future_paris = get_weather(city="Paris") # starts immediately +future_tokyo = get_weather(city="Tokyo") # starts immediately +paris = await future_paris # wait for result +tokyo = await future_tokyo + +paris_c = await convert_temp(fahrenheit=paris["temp_f"]) +tokyo_c = await convert_temp(fahrenheit=tokyo["temp_f"]) + +{"paris": paris_c, "tokyo": tokyo_c} +``` + +Both cities are looked up in parallel, then both temperatures are converted — all in a single model call. + +### How It Works + +When the model calls `run_code`, the [`CodeExecutionToolset`][pydantic_harness.toolsets.code_execution.CodeExecutionToolset]: + +1. Generates Python function signatures from the wrapped tools so the model knows what's available +2. Builds a [prompt](#customizing-the-prompt) describing the execution model (fire-then-await parallelism, keyword-only arguments, etc.) +3. Sends the model's code to the [environment](#environments) for execution +4. Intercepts function calls made by the code and routes them through the normal Pydantic AI tool pipeline (with validation, tracing, and dependency injection) +5. Returns the result back to the model + +The execution model exposed to the model works as follows: + +- Each `run_code` call runs in an isolated environment — variables don't persist between calls +- Functions are async — call with `await`, e.g. `result = await get_items()` +- To run independent calls concurrently, fire them first (which starts them immediately), then `await` the results +- The last expression evaluated is the return value +- All function arguments must be keyword-only + +Execution errors (syntax, type, or runtime) are automatically sent back to the model so it can fix its code and try again, up to `max_retries` times (default 3). + +```mermaid +sequenceDiagram + participant Agent + participant LLM + participant Environment + + Note over Agent: Send instructions + tool definitions + Agent ->> LLM: Instructions + run_code tool + activate LLM + Note over LLM: Model writes Python code + + LLM ->> Agent: run_code(code="...") + deactivate LLM + activate Agent + Agent ->> Environment: Execute code in sandbox + activate Environment + Environment ->> Agent: get_weather(city="Paris") + Agent -->> Environment: {"city": "Paris", "temp_f": 72, ...} + Environment ->> Agent: get_weather(city="Tokyo") + Agent -->> Environment: {"city": "Tokyo", "temp_f": 65, ...} + Environment ->> Agent: convert_temp(fahrenheit=72) + Agent -->> Environment: 22.2 + Environment ->> Agent: convert_temp(fahrenheit=65) + Agent -->> Environment: 18.3 + Note over Environment: Code evaluates final expression + Environment ->> Agent: {"paris": 22.2, "tokyo": 18.3} + deactivate Environment + Agent ->> LLM: Tool result + deactivate Agent + activate LLM + Note over LLM: Model formats final response + + LLM ->> Agent: "Paris is 22.2°C and Tokyo is 18.3°C" + deactivate LLM +``` + +### MCP and Other Toolsets + +Code mode works with any [toolset](toolsets.md), including [MCP servers](mcp/client.md). Tool names that aren't valid Python identifiers (e.g. `search-records`) are automatically sanitized (to `search_records`) so the model can call them naturally from code. + +```python {test="skip" lint="skip"} +from pydantic_ai import Agent +from pydantic_ai.mcp import MCPServerStdio +from pydantic_harness.toolsets.code_execution import CodeExecutionToolset + +server = MCPServerStdio('uv', args=['run', 'my-mcp-server']) + +agent = Agent( + 'anthropic:claude-sonnet-4-5', + toolsets=[CodeExecutionToolset(toolset=server)], +) +``` + +## Environments + +Code execution uses a pluggable environment to execute generated code. You can pass an environment instance or a string shorthand (`'monty'` or `'docker'`) to [`CodeExecutionToolset`][pydantic_harness.toolsets.code_execution.CodeExecutionToolset]: + +```python {test="skip" lint="skip"} +# These are equivalent: +CodeExecutionToolset('monty', toolset=tools) +CodeExecutionToolset(MontyEnvironment(), toolset=tools) +``` + +The environment you choose determines the execution sandbox, security boundaries, and available Python features. + +### Monty Environment + +[Monty](https://github.com/pydantic/monty) is a minimal, secure Python interpreter built by the Pydantic team specifically for code execution. It is the default environment. + +!!! danger "Early Stage — Not for Untrusted Prompts" + Monty is under active development. **Do not use it in production systems where untrusted user prompts are passed directly to the model** — the model could be manipulated into generating malicious code. While Monty is designed for safe sandboxed execution, it has not yet undergone the level of hardening required for adversarial inputs. This restriction will be relaxed as Monty matures. + +[`MontyEnvironment`][pydantic_harness.environments.MontyEnvironment] runs a restricted Python subset directly in your process — no containers, no network, no cold starts: + +```python {test="skip" lint="skip"} +from pydantic_harness.environments.monty import MontyEnvironment +from pydantic_harness.toolsets.code_execution import CodeExecutionToolset +from pydantic_ai.toolsets import FunctionToolset + +tools = FunctionToolset(tools=[...]) +toolset = CodeExecutionToolset(MontyEnvironment(), toolset=tools) # (1)! +``` + +1. Equivalent to `CodeExecutionToolset(toolset=tools)` or `CodeExecutionToolset('monty', toolset=tools)`, since Monty is the default. + +Monty type-checks generated code (via [ty](https://github.com/astral-sh/ty)) before execution, catching errors early and giving the model precise feedback to fix its code. It can also freeze and restore its full interpreter state via snapshot-based checkpointing, enabling efficient resume without re-executing code from scratch. + +Because Monty runs a restricted Python subset, the environment automatically instructs the model about what's not available: no imports, no classes, no decorators — only the provided functions and builtins. + +### Driver-Based Environments + +For environments that need full CPython compatibility (arbitrary imports, C extensions) or stronger isolation guarantees, the [`DriverBasedEnvironment`][pydantic_harness.environments.DriverBasedEnvironment] base class supports executing code in any sandbox that can run a Python process. It uses a lightweight [driver script](https://github.com/pydantic/pydantic-ai/blob/main/pydantic_ai_slim/pydantic_ai/toolsets/code_execution/_driver.py) that communicates over stdin/stdout. + + + +#### Docker + +[`DockerEnvironment`][pydantic_harness.environments.DockerEnvironment] runs code inside a Docker container with hardened security defaults. The environment manages the full container lifecycle automatically: + +```python {test="skip" lint="skip"} +from pydantic_ai import Agent +from pydantic_harness.environments.docker import DockerEnvironment # NOTE: DockerEnvironment not yet available in pydantic-harness +from pydantic_harness.toolsets.code_execution import CodeExecutionToolset +from pydantic_ai.toolsets import FunctionToolset + +tools = FunctionToolset(tools=[...]) +agent = Agent( + 'anthropic:claude-sonnet-4-5', + toolsets=[CodeExecutionToolset(DockerEnvironment(), toolset=tools)], # (1)! +) + +result = await agent.run('...') # (2)! +``` + +1. No container ID needed — the environment creates a security-hardened container automatically. +2. The container is created when the agent starts and removed when it finishes. No manual cleanup required. + +By default, managed containers run with restrictive security settings: + +- `--network none` — no network access +- `--cap-drop ALL` — all Linux capabilities dropped +- `--read-only` — read-only root filesystem +- `--security-opt no-new-privileges` — no privilege escalation +- `--user nobody` — unprivileged user +- `--memory 512m` — memory limit with swap disabled +- `--pids-limit 256` — process count limit +- `--cpus 1` — CPU limit +- `--tmpfs /tmp:noexec,nosuid,size=64m` — writable scratch space + + + +Override specific settings with [`DockerSecuritySettings`][pydantic_harness.environments.docker.DockerSecuritySettings]: + +```python {test="skip" lint="skip"} +from pydantic_harness.environments.docker import DockerEnvironment, DockerSecuritySettings # NOTE: DockerEnvironment not yet available in pydantic-harness + +env = DockerEnvironment( + security=DockerSecuritySettings(network=True), # (1)! +) +``` + +1. Allow network access while keeping all other security defaults. + +For pre-existing containers, pass a `container_id` to use the environment without lifecycle management: + +```python {test="skip" lint="skip"} +# NOTE: DockerEnvironment not yet available in pydantic-harness +env = DockerEnvironment(container_id='my-sandbox-container') +``` + +#### Building a Custom Environment + +To support a new sandbox (e.g., Firecracker, WebAssembly), implement the [`DriverTransport`][pydantic_harness.environments._driver.DriverTransport] protocol: + +```python {test="skip" lint="skip"} +from pydantic_harness.environments._driver import DriverTransport + + +class MyTransport(DriverTransport): + async def read_line(self) -> bytes: ... + async def write_line(self, data: bytes) -> None: ... + async def read_stderr(self) -> bytes: ... + async def kill(self) -> None: ... +``` + +Then subclass [`DriverBasedEnvironment`][pydantic_harness.environments._driver.DriverBasedEnvironment] and implement `_start_driver()` to launch your sandbox and return the transport. All protocol handling, tool dispatch, and checkpoint/resume logic is inherited. + +## Customizing the Prompt + +The tool description that explains available functions and the execution model to the model is generated by [`build_default_description`][pydantic_harness.toolsets.code_execution.build_default_description]. You can customize it by passing a `description` to [`CodeExecutionToolset`][pydantic_harness.toolsets.code_execution.CodeExecutionToolset] — either a string to replace just the preamble text, or a [`DescriptionFunc`][pydantic_harness.toolsets.code_execution.DescriptionFunc] callback for full control: + +```python {test="skip" lint="skip"} +from pydantic_harness.toolsets.code_execution import CodeExecutionToolset +from pydantic_ai.toolsets import FunctionToolset + +tools = FunctionToolset(tools=[...]) + +# Pass a string to customize the preamble while keeping the default structure: +code_toolset = CodeExecutionToolset( + toolset=tools, + description='Use this tool to run Python code that calls the available functions.', +) +``` + +```python {test="skip" lint="skip"} +from pydantic_harness._python_signature import FunctionSignature, TypeSignature +from pydantic_harness.toolsets.code_execution import CodeExecutionToolset +from pydantic_ai.toolsets import FunctionToolset + + +def my_description_func( + signatures: list[FunctionSignature], + referenced_types: list[TypeSignature], + environment_instructions: str | None, +) -> str: + funcs = '\n'.join(str(sig) for sig in signatures) + return f'Write Python code using these functions:\n\n{funcs}' + + +tools = FunctionToolset(tools=[...]) +code_toolset = CodeExecutionToolset( + toolset=tools, + description=my_description_func, +) +``` + +## Known Limitations + +- **Tool approval and deferral** — [deferred tools](deferred-tools.md) (tools that require approval or external execution) are not yet supported in code mode. +- **Streaming** — code execution does not currently support streaming partial results. +- **Monty's restricted Python** — Monty runs a subset of Python: no imports, no classes, no decorators. The model is instructed about these restrictions, but complex code may need a [driver-based environment](#driver-based-environments) with full CPython. + +## See Also + +- [Toolsets](toolsets.md) — managing collections of tools, including composition, filtering, and dynamic toolsets +- [Function Tools](tools.md) — defining the tools that code mode wraps +- [MCP Client](mcp/client.md) — using MCP servers as toolsets with code mode +- [Dependencies](dependencies.md) — dependency injection, which works through code mode's tool pipeline +- [Logfire](logfire.md) — tracing and debugging agent runs, including code execution tool calls diff --git a/docs/environments.md b/docs/environments.md new file mode 100644 index 0000000..4550b16 --- /dev/null +++ b/docs/environments.md @@ -0,0 +1,338 @@ +# Execution Environments & Sandboxes + +Pydantic AI provides [`ExecutionEnvironment`][pydantic_harness.environments.ExecutionEnvironment] — an abstraction for environments where agents can execute commands, read/write files, and search the filesystem — along with [`ExecutionEnvironmentToolset`][pydantic_harness.environments.ExecutionEnvironmentToolset], a ready-made [toolset](toolsets.md) that exposes these capabilities as tools. + +This is the foundation for building coding agents, data analysis bots, and other agents that need to interact with a shell and filesystem. + +## Quick Start + +```python {title="environments_quickstart.py" test="skip"} +from pydantic_ai import Agent +from pydantic_harness.environments import ExecutionEnvironmentToolset +from pydantic_harness.environments.local import LocalEnvironment + +env = LocalEnvironment(root_dir='/tmp/workspace') +toolset = ExecutionEnvironmentToolset(env) + +agent = Agent('openai:gpt-5.2', toolsets=[toolset]) + +async def main(): + async with env: + result = await agent.run('Create a Python script that prints the first 10 Fibonacci numbers, then run it.') + print(result.output) +``` + +## Environments + +An [`ExecutionEnvironment`][pydantic_harness.environments.ExecutionEnvironment] defines where and how commands run. Four implementations are included: + + + +| Environment | Isolation | Use case | +|---|---|---| +| [`LocalEnvironment`][pydantic_harness.environments.local.LocalEnvironment] | None — runs on host | Development, testing, trusted agents | +| [`DockerEnvironment`][pydantic_harness.environments.docker.DockerEnvironment] | Container-level | Production, untrusted code | +| [`MemoryEnvironment`][pydantic_harness.environments.memory.MemoryEnvironment] | In-memory (no filesystem) | Unit testing | +| [`MontyEnvironment`][pydantic_harness.environments.monty.MontyEnvironment] | Sandboxed interpreter | Code execution only | + +All environments are async context managers. Enter the environment before running the agent, and exit it to clean up: + + + +```python {title="environments_lifecycle.py" test="skip"} +from pydantic_harness.environments.docker import DockerEnvironment # NOTE: DockerEnvironment not yet available in pydantic-harness + +env = DockerEnvironment(image='python:3.12-slim') + +async def main(): + async with env: + result = await env.shell('python -c "print(42)"') + print(result.output) +``` + +### LocalEnvironment + +[`LocalEnvironment`][pydantic_harness.environments.local.LocalEnvironment] runs commands as local subprocesses within a specified root directory. It provides no isolation — use it for development, testing, and trusted agents. + +```python {title="environments_local.py"} +from pydantic_harness.environments.local import LocalEnvironment + +env = LocalEnvironment( + root_dir='/tmp/workspace', + env_vars={'PYTHONPATH': '/tmp/workspace/lib'}, + inherit_env=True, # inherit host environment variables (default) +) +``` + +File operations (read, write, edit, ls, glob, grep) are confined to the root directory — path traversal attempts raise `PermissionError`. + +!!! info "Environment variable inheritance" + By default, `LocalEnvironment` inherits the host's environment variables. Set `inherit_env=False` for a clean environment where only explicitly provided `env_vars` (and per-call `env` overrides) are available. This is useful for reproducibility and testing. + + + +### DockerEnvironment + +[`DockerEnvironment`][pydantic_harness.environments.docker.DockerEnvironment] runs commands inside a Docker container with configurable resource limits, security options, and network access. + +Requires the `docker` package: `pip install pydantic-ai-slim[docker-sandbox]` + +```python {title="environments_docker.py" test="skip"} +from pydantic_harness.environments.docker import DockerEnvironment # NOTE: DockerEnvironment not yet available in pydantic-harness + +env = DockerEnvironment( + image='my-sandbox:latest', + env_vars={'MPLBACKEND': 'Agg'}, + memory_limit='512m', + cpu_limit=1.0, + network_disabled=True, +) +``` + +#### Building a custom Docker image + +`DockerEnvironment` runs whatever image you give it — it doesn't install packages at startup. Pre-build a custom image with any libraries your agent needs, so containers start fast and reproducibly. + +**Example Dockerfile** — a Python data-science sandbox: + +```dockerfile {title="Dockerfile" test="skip" lint="skip"} +FROM python:3.12-slim + +# Install OS-level tools the agent might use (optional) +RUN apt-get update \ + && apt-get install -y --no-install-recommends git curl jq \ + && rm -rf /var/lib/apt/lists/* + +# Install Python packages +RUN pip install --no-cache-dir numpy pandas matplotlib requests + +WORKDIR /workspace +``` + +Build and tag the image: + +```bash +docker build -t my-sandbox:latest . +``` + +Then pass the tag to `DockerEnvironment`: + +```python {title="environments_docker_custom.py" test="skip"} +from pydantic_harness.environments.docker import DockerEnvironment # NOTE: DockerEnvironment not yet available in pydantic-harness + +env = DockerEnvironment(image='my-sandbox:latest') +``` + +!!! tip "Tips for custom images" + + - **Start from a slim base** (`python:3.12-slim`, `node:22-slim`, etc.) to keep image size and attack surface small. + - **Pin package versions** (e.g. `numpy==2.2.3`) for reproducible builds. + - **Use `--no-cache-dir`** with pip to avoid bloating the image with cached wheels. + - **Build once, run many times.** The image is pulled from the local Docker cache on each `DockerEnvironment` startup — no rebuild needed. + - **Use a registry** for team or CI workflows: push your image to Docker Hub, GitHub Container Registry, or a private registry, then reference it by its full name (e.g. `ghcr.io/myorg/my-sandbox:latest`). + - **For Node.js** or other runtimes, adjust the base image and install command accordingly: + + ```dockerfile {test="skip" lint="skip"} + FROM node:22-slim + RUN npm install -g typescript ts-node express + WORKDIR /workspace + ``` + +For running untrusted code, you can harden the container with Linux security options: + +```python {title="environments_docker_hardened.py" test="skip"} +from pydantic_harness.environments.docker import DockerEnvironment # NOTE: DockerEnvironment not yet available in pydantic-harness + +env = DockerEnvironment( + image='python:3.12-slim', + network_disabled=True, + read_only=True, + cap_drop=['ALL'], + security_opt=['no-new-privileges'], + user='nobody', + pids_limit=256, + tmpfs={'/tmp': 'noexec,nosuid,size=64m', '/workspace': 'size=128m'}, + init=True, + memory_limit='512m', + cpu_limit=1.0, +) +``` + +This drops all Linux capabilities, prevents privilege escalation, runs as an unprivileged user, limits the number of processes, and makes the root filesystem read-only (with writable tmpfs mounts for scratch space and the working directory). + +## ExecutionEnvironmentToolset + +[`ExecutionEnvironmentToolset`][pydantic_harness.environments.ExecutionEnvironmentToolset] wraps an environment and exposes coding-agent-style tools that models are well-trained on (matching tools that popular coding agents expose): + +| Tool | Description | +|---|---| +| `ls` | List directory contents | +| `shell` | Execute shell commands | +| `read_file` | Read files with line numbers (renders images for multimodal models) | +| `write_file` | Create or overwrite files | +| `replace_str` | Edit files by exact string replacement | +| `glob` | Find files by pattern | +| `grep` | Search file contents with regex | + +Tools are dynamically registered based on the environment's capabilities. You can selectively include or exclude capabilities: + +```python {title="environments_selective_tools.py"} +from pydantic_harness.environments import ExecutionEnvironmentToolset +from pydantic_harness.environments.memory import MemoryEnvironment + +# Only file tools — no shell or search +toolset = ExecutionEnvironmentToolset( + MemoryEnvironment(), + include=frozenset({'read_file', 'write_file', 'edit_file'}), +) +``` + +### Using with an Agent + +The toolset manages the environment lifecycle when used as a context manager: + + + +```python {title="environments_agent.py" test="skip"} +from pydantic_ai import Agent +from pydantic_harness.environments import ExecutionEnvironmentToolset +from pydantic_harness.environments.docker import DockerEnvironment # NOTE: DockerEnvironment not yet available in pydantic-harness + +env = DockerEnvironment(image='python:3.12-slim') +toolset = ExecutionEnvironmentToolset(env) + +agent = Agent('openai:gpt-5.2', toolsets=[toolset]) + +async def main(): + async with toolset: # starts the Docker container + result = await agent.run('Fetch https://httpbin.org/get and print the response') + print(result.output) + # container cleaned up automatically +``` + +### Environment Overrides + +You can swap the backing environment at runtime using [`use_environment()`][pydantic_harness.environments.ExecutionEnvironmentToolset.use_environment]: + + + +```python {title="environments_override.py" test="skip"} +from pydantic_ai import Agent +from pydantic_harness.environments import ExecutionEnvironmentToolset +from pydantic_harness.environments.docker import DockerEnvironment # NOTE: DockerEnvironment not yet available in pydantic-harness +from pydantic_harness.environments.local import LocalEnvironment + +toolset = ExecutionEnvironmentToolset(LocalEnvironment('/tmp/dev')) + +agent = Agent('openai:gpt-5.2', toolsets=[toolset]) + +async def main(): + # Default: local environment + async with LocalEnvironment('/tmp/dev') as local_env: + with toolset.use_environment(local_env): + await agent.run('echo "running locally"') + + # Override: Docker environment for untrusted input + async with DockerEnvironment() as docker_env: # NOTE: DockerEnvironment not yet available in pydantic-harness + with toolset.use_environment(docker_env): + await agent.run('echo "running in Docker"') +``` + +## Per-Call Environment Variables + +All environments support per-call environment variables via the `env` parameter on [`shell()`][pydantic_harness.environments.ExecutionEnvironment.shell] and [`create_process()`][pydantic_harness.environments.ExecutionEnvironment.create_process]. These are merged on top of any baseline `env_vars`: + +```python {title="environments_env_vars.py" test="skip"} +from pydantic_harness.environments.local import LocalEnvironment + +env = LocalEnvironment(env_vars={'BASE_URL': 'https://api.example.com'}) + +async def main(): + async with env: + # Uses BASE_URL from baseline + API_KEY from per-call + result = await env.shell( + 'curl -H "Authorization: Bearer $API_KEY" $BASE_URL/data', + env={'API_KEY': 'sk-test-123'}, + ) + print(result.output) +``` + +## Interactive Processes + +For long-running or interactive workloads, use [`create_process()`][pydantic_harness.environments.ExecutionEnvironment.create_process] to get an [`ExecutionProcess`][pydantic_harness.environments.ExecutionProcess] with bidirectional streaming I/O: + +```python {title="environments_process.py" test="skip"} +from pydantic_harness.environments.local import LocalEnvironment + +env = LocalEnvironment() + +async def main(): + async with env: + async with await env.create_process('python3 -u worker.py') as proc: + await proc.send(b'{"task": "analyze"}\n') + response = await proc.recv(timeout=10.0) + print(response.decode()) +``` + +## Execution Model + +Each call to `shell()` or `create_process()` starts a fresh process. Shell state (like `cd`, shell variables) does not persist between calls. This is the same model used by other coding agents like Claude Code and Codex. + +To run commands in a specific directory, chain them: + +```python {title="environments_chaining.py" test="skip" lint="skip"} +result = await env.shell('cd /some/path && python script.py') +``` + +Filesystem changes (created files, installed packages) persist for the lifetime of the environment. + +## Building a Custom Environment + +You can implement [`ExecutionEnvironment`][pydantic_harness.environments.ExecutionEnvironment] to integrate with any execution backend. The only abstract member is `capabilities`; override the methods that match your declared capabilities. Override [`create_process()`][pydantic_harness.environments.ExecutionEnvironment.create_process] if you need interactive process support. + +```python {title="environments_custom.py" test="skip" lint="skip"} +from typing import Literal + +from pydantic_harness.environments import ExecutionEnvironment, ExecutionProcess, ExecutionResult, FileInfo +from pydantic_harness.environments._base import Capability + +class MyCloudEnvironment(ExecutionEnvironment): + @property + def capabilities(self) -> frozenset[Capability]: + return frozenset({'shell', 'read_file', 'write_file', 'replace_str', 'ls', 'glob', 'grep'}) + + async def shell( + self, command: str, *, timeout: float | None = 120, env: dict[str, str] | None = None + ) -> ExecutionResult: + # Run a command in your cloud environment + ... + + async def read_file( + self, path: str, *, offset: int = 0, limit: int = 2000 + ) -> str | bytes: + ... + + async def write_file(self, path: str, content: str | bytes) -> None: + ... + + async def replace_str( + self, path: str, old: str, new: str, *, replace_all: bool = False + ) -> int: + ... + + async def ls(self, path: str = '.') -> list[FileInfo]: + ... + + async def glob(self, pattern: str, *, path: str = '.') -> list[str]: + ... + + async def grep( + self, + pattern: str, + *, + path: str | None = None, + glob_pattern: str | None = None, + output_mode: Literal['content', 'files_with_matches', 'count'] = 'content', + ) -> str: + ... +``` diff --git a/examples/code_execution/__init__.py b/examples/code_execution/__init__.py new file mode 100644 index 0000000..9ff8aac --- /dev/null +++ b/examples/code_execution/__init__.py @@ -0,0 +1,14 @@ +"""Code Execution Examples. + +Available Examples: + - batch_operations.py: Create multiple calendar events using code execution (no external deps) + - follow_the_money.py: Fraud ring detection via transaction graph traversal (no external deps) + - github_pr_analysis.py: PR velocity analysis via GitHub REST API (requires GITHUB_PERSONAL_ACCESS_TOKEN) + - pr_comment_buckets.py: PR discussion intensity analysis via GitHub MCP (requires GITHUB_PERSONAL_ACCESS_TOKEN) + +Run: + uv run -m examples.code_execution.batch_operations + uv run -m examples.code_execution.follow_the_money + uv run -m examples.code_execution.github_pr_analysis + uv run -m examples.code_execution.pr_comment_buckets +""" diff --git a/examples/code_execution/batch_operations.py b/examples/code_execution/batch_operations.py new file mode 100644 index 0000000..4e1b3cd --- /dev/null +++ b/examples/code_execution/batch_operations.py @@ -0,0 +1,226 @@ +"""Code Execution Example: Batch Calendar Event Creation. + +This example shows how code execution reduces LLM roundtrips when creating multiple +calendar events. With traditional tool calling, each event requires a separate +roundtrip. With code execution, the LLM writes a loop that creates all events in one go. + +Run: + uv run -m examples.code_execution.batch_operations +""" + +from __future__ import annotations + +import asyncio +from dataclasses import dataclass +from typing import Any + +import logfire + +from pydantic_ai import Agent +from pydantic_harness.environments.monty import MontyEnvironment +from pydantic_ai.messages import ModelResponse, RetryPromptPart +from pydantic_ai.run import AgentRunResult +from pydantic_ai.toolsets import FunctionToolset +from pydantic_harness.toolsets.code_execution import CodeExecutionToolset + +# ============================================================================= +# Configuration +# ============================================================================= + +PROMPT = """ +Create calendar events for a daily standup meeting at 9:00 AM for every day +in the first week of January 2025 (January 1-7). + +Return a summary with: +- Total events created +- List of all event dates +""" + +MODEL = 'gateway/anthropic:claude-sonnet-4-5' +MAX_RETRIES = 3 + +# ============================================================================= +# Mock Calendar Tools +# ============================================================================= + +# Simulated calendar storage +_calendar_events: list[dict[str, Any]] = [] + + +def create_calendar_event(title: str, date: str, time: str) -> dict[str, Any]: + """Create a calendar event. + + Args: + title: The title of the event. + date: The date of the event in YYYY-MM-DD format. + time: The time of the event in HH:MM format. + + Returns: + The created event with its ID. + """ + event_id = len(_calendar_events) + 1 + event = {'id': event_id, 'title': title, 'date': date, 'time': time} + _calendar_events.append(event) + return {'success': True, 'event_id': event_id, 'event': event} + + +def list_calendar_events() -> list[dict[str, Any]]: + """List all calendar events. + + Returns: + List of all calendar events. + """ + return _calendar_events.copy() + + +def create_toolset() -> FunctionToolset[None]: + """Create the calendar toolset.""" + toolset: FunctionToolset[None] = FunctionToolset() + toolset.add_function(create_calendar_event) + toolset.add_function(list_calendar_events) + return toolset + + +# ============================================================================= +# Agent Factories +# ============================================================================= + + +def create_tool_calling_agent(toolset: FunctionToolset[None]) -> Agent[None, str]: + """Create agent with standard tool calling.""" + return Agent( + MODEL, + toolsets=[toolset], + system_prompt='You are a calendar assistant. Use the available tools to manage calendar events.', + ) + + +def create_code_execution_agent(toolset: FunctionToolset[None]) -> Agent[None, str]: + """Create agent with code execution (tools as Python functions).""" + environment = MontyEnvironment() + code_toolset: CodeExecutionToolset[None] = CodeExecutionToolset( + environment, + toolset=toolset, + max_retries=MAX_RETRIES, + ) + return Agent( + MODEL, + toolsets=[code_toolset], + system_prompt='You are a calendar assistant. Use the available tools to manage calendar events.', + ) + + +# ============================================================================= +# Metrics Collection +# ============================================================================= + + +@dataclass +class RunMetrics: + """Metrics collected from an agent run.""" + + mode: str + request_count: int + input_tokens: int + output_tokens: int + retry_count: int + output: str + + @property + def total_tokens(self) -> int: + return self.input_tokens + self.output_tokens + + +def extract_metrics(result: AgentRunResult[str], mode: str) -> RunMetrics: + """Extract metrics from agent result.""" + request_count = 0 + input_tokens = 0 + output_tokens = 0 + retry_count = 0 + + for msg in result.all_messages(): + if isinstance(msg, ModelResponse): + request_count += 1 + if msg.usage: + input_tokens += msg.usage.input_tokens or 0 + output_tokens += msg.usage.output_tokens or 0 + for part in getattr(msg, 'parts', []): + if isinstance(part, RetryPromptPart): + retry_count += 1 + + return RunMetrics( + mode=mode, + request_count=request_count, + input_tokens=input_tokens, + output_tokens=output_tokens, + retry_count=retry_count, + output=result.output, + ) + + +# ============================================================================= +# Run Functions +# ============================================================================= + + +async def run_tool_calling(toolset: FunctionToolset[None]) -> RunMetrics: + """Run with standard tool calling.""" + global _calendar_events + _calendar_events = [] # Reset calendar + + with logfire.span('tool_calling'): + agent = create_tool_calling_agent(toolset) + result = await agent.run(PROMPT) + return extract_metrics(result, 'tool_calling') + + +async def run_code_execution(toolset: FunctionToolset[None]) -> RunMetrics: + """Run with code execution tool calling.""" + global _calendar_events + _calendar_events = [] # Reset calendar + + with logfire.span('code_execution_tool_calling'): + agent = create_code_execution_agent(toolset) + code_toolset = agent.toolsets[0] + async with code_toolset: + result = await agent.run(PROMPT) + return extract_metrics(result, 'code_execution') + + +# ============================================================================= +# Main Demo +# ============================================================================= + + +def log_metrics(metrics: RunMetrics) -> None: + """Log metrics to logfire.""" + logfire.info( + '{mode} completed: {requests} requests, {tokens} tokens', + mode=metrics.mode, + requests=metrics.request_count, + tokens=metrics.total_tokens, + input_tokens=metrics.input_tokens, + output_tokens=metrics.output_tokens, + retries=metrics.retry_count, + ) + + +async def main() -> None: + logfire.configure(service_name='code-execution-batch-demo') + logfire.instrument_pydantic_ai() + + toolset = create_toolset() + + with logfire.span('demo_tool_calling'): + trad = await run_tool_calling(toolset) + log_metrics(trad) + + with logfire.span('demo_code_execution'): + code = await run_code_execution(toolset) + log_metrics(code) + + print('View traces: https://logfire.pydantic.dev') + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/examples/code_execution/follow_the_money.py b/examples/code_execution/follow_the_money.py new file mode 100644 index 0000000..22bfa87 --- /dev/null +++ b/examples/code_execution/follow_the_money.py @@ -0,0 +1,1601 @@ +"""Code Execution Example: Fraud Ring Detection via Transaction Graph Traversal. + +This example demonstrates where code execution doesn't just reduce roundtrips -- it makes +the task qualitatively easier to solve correctly. The scenario: given a flagged account, +trace money flows through a layered transaction network, identify convergence points +(accounts receiving funds from multiple upstream sources), and rank suspects. + +The API is deliberately decomposed and hostile to manual traversal: + + 1. Paginated results (3 per page) -- the LLM must loop to get the full picture. + 2. Multi-currency transactions requiring FX rate lookups and arithmetic. + 3. Batch wire transfers that hide real recipients behind an extra API call. + 4. Verbose records with many irrelevant fields per response. + +With traditional tool calling, the LLM must: +- Make 3-4 round-trips per account just to paginate through transactions +- Call get_exchange_rate + mentally multiply for every foreign-currency transfer +- Notice that "batch" transactions need expansion, then call get_batch_details +- Mentally track visited accounts, running totals, and source sets across ~100 calls +- All of this correctly across 3 hops, ~15 accounts, and ~45 transactions + +With code execution, the LLM writes a BFS loop with a while-page loop inside, does +`amount * rate` inline, expands batches in a for-loop, and returns a summary. + +Run: + uv run -m examples.code_execution.follow_the_money +""" + +from __future__ import annotations + +import asyncio +from dataclasses import dataclass +from typing import Any + +import logfire + +from pydantic_ai import Agent +from pydantic_harness.environments.monty import MontyEnvironment +from pydantic_ai.messages import ModelResponse, RetryPromptPart +from pydantic_ai.run import AgentRunResult +from pydantic_ai.toolsets import FunctionToolset +from pydantic_harness.toolsets.code_execution import CodeExecutionToolset + +# ============================================================================= +# Configuration +# ============================================================================= + +AMOUNT_THRESHOLD = 5000 +PAGE_SIZE = 3 + +PROMPT = """\ +Starting from account ACC-001, which has been flagged for suspicious activity, \ +trace all outgoing money flows up to 3 hops deep. Only follow transactions of \ +$5,000 or more (after converting to USD using the exchange rate tool for any \ +non-USD transactions) -- smaller amounts are likely legitimate business expenses. + +Some transactions are batch wire transfers that bundle payments to multiple \ +recipients. You MUST expand these by calling the batch details tool -- the \ +transaction list and transaction detail endpoints only show the batch total, \ +not where the money actually went. + +Build the full transaction network and identify convergence points -- accounts \ +that receive money from 2 or more different upstream accounts in the traced \ +network. These suggest layering. Flag all convergence points and report their \ +names, number of upstream sources, and total inflows (converted to USD).\ +""" + +MODEL = 'gateway/anthropic:claude-sonnet-4-5' +MAX_RETRIES = 5 + +# ============================================================================= +# Mock Transaction Network +# ============================================================================= + +# Account names -- used by the transaction builder and ground truth +_ACCOUNT_NAMES: dict[str, str] = { + 'ACC-001': 'Viktor Petrov', + 'ACC-002': 'Oceanic Trading LLC', + 'ACC-003': 'Baltic Shipping Co', + 'ACC-004': 'Maria Santos', + 'ACC-005': 'Sunrise Consulting', + 'ACC-006': 'Digital Media Partners', + 'ACC-007': 'Northern Logistics', + 'ACC-008': 'Quick Cash Services', + 'ACC-009': 'Golden Gate Holdings', + 'ACC-010': 'Harbor Real Estate', + 'ACC-011': 'Cafe Milano', + 'ACC-012': 'Island Getaway Travel', + 'ACC-013': 'City Newsstand', + 'ACC-014': 'Metro Dry Cleaning', + 'ACC-015': 'Pine Street Deli', +} + +_EXTERNAL_NAMES: dict[str, str] = { + 'EXT-101': 'Pacific Rim Imports', + 'EXT-102': 'Nordic Freight Alliance', + 'EXT-103': 'Apex Strategy Group', + 'EXT-104': 'Broadview Media Inc', + 'EXT-105': 'Continental Warehousing', + 'EXT-106': 'Premier Currency Exchange', + 'EXT-107': 'Westfield Capital Partners', + 'EXT-108': 'Santos Family Trust', + 'EXT-109': 'Eastern European Trade Consortium', + 'EXT-110': 'Meridian Advisory Services', + 'EXT-111': 'Bay Area Venture Fund', + 'EXT-112': 'Sierra Investment Corp', + 'EXT-113': 'Creative Solutions Agency', + 'EXT-114': 'TransAtlantic Commodities', + 'EXT-115': 'Caspian Energy Group', + 'EXT-116': 'Silk Road Textiles', + 'EXT-117': 'Danube River Freight', + 'EXT-118': 'Coral Bay Exports', + 'EXT-119': 'Summit Peak Advisors', + 'EXT-120': 'Black Sea Minerals', + 'EXT-121': 'Fjord Line Cargo', + 'EXT-122': 'Redwood Strategies LLC', + 'EXT-123': 'Amber Coast Holdings', +} + +_ALL_NAMES: dict[str, str] = {**_ACCOUNT_NAMES, **_EXTERNAL_NAMES} + +# Verbose account records -- realistic financial API response shape. +# The LLM only needs id/name for the task, but gets 15-20 fields per call. +_ACCOUNTS: dict[str, dict[str, Any]] = { + 'ACC-001': { + 'id': 'ACC-001', + 'name': 'Viktor Petrov', + 'type': 'individual', + 'status': 'flagged', + 'date_of_birth': '1978-06-14', + 'nationality': 'Russian Federation', + 'address': '42 Nevsky Prospect, Apt 15, St. Petersburg, Russia 191025', + 'phone': '+7-812-555-0147', + 'email': 'v.petrov@petrovconsulting.com', + 'occupation': 'Import/Export Consultant', + 'employer': 'Self-employed', + 'annual_income': '$180,000-$250,000', + 'tax_id': 'XX-XXX7823', + 'risk_score': 78, + 'kyc_status': 'review_pending', + 'last_review_date': '2024-01-05', + 'account_opened': '2020-03-15', + 'account_balance': 12450.00, + }, + 'ACC-002': { + 'id': 'ACC-002', + 'name': 'Oceanic Trading LLC', + 'type': 'business', + 'status': 'active', + 'registration_number': 'LLC-2018-33847', + 'jurisdiction': 'Delaware', + 'incorporation_date': '2018-07-22', + 'registered_agent': 'National Corporate Services', + 'address': '1200 Harbor Blvd, Suite 400, Wilmington, DE 19801', + 'phone': '+1-302-555-0182', + 'email': 'accounts@oceanictrading.com', + 'industry': 'International Trade', + 'annual_revenue_band': '$5M-$10M', + 'employee_count': 28, + 'primary_contact': 'Elena Vasquez, CFO', + 'tax_id': 'XX-XXX4190', + 'risk_score': 45, + 'kyc_status': 'current', + 'last_review_date': '2023-09-18', + 'account_opened': '2018-08-01', + 'account_balance': 234780.50, + }, + 'ACC-003': { + 'id': 'ACC-003', + 'name': 'Baltic Shipping Co', + 'type': 'business', + 'status': 'active', + 'registration_number': 'LLC-2017-28934', + 'jurisdiction': 'New York', + 'incorporation_date': '2017-03-10', + 'registered_agent': 'Harbor Legal Services', + 'address': '45 Water Street, 12th Floor, New York, NY 10004', + 'phone': '+1-212-555-0293', + 'email': 'finance@balticshipping.com', + 'industry': 'Maritime Logistics', + 'annual_revenue_band': '$10M-$25M', + 'employee_count': 65, + 'primary_contact': 'Andrei Volkov, Director of Operations', + 'tax_id': 'XX-XXX8821', + 'risk_score': 38, + 'kyc_status': 'current', + 'last_review_date': '2023-11-02', + 'account_opened': '2017-04-15', + 'account_balance': 567200.00, + }, + 'ACC-004': { + 'id': 'ACC-004', + 'name': 'Maria Santos', + 'type': 'individual', + 'status': 'active', + 'date_of_birth': '1985-11-28', + 'nationality': 'Brazilian', + 'address': '789 Brickell Ave, Apt 2201, Miami, FL 33131', + 'phone': '+1-305-555-0174', + 'email': 'maria.santos@gmail.com', + 'occupation': 'Real Estate Agent', + 'employer': 'Luxe Properties International', + 'annual_income': '$95,000-$150,000', + 'tax_id': 'XX-XXX5567', + 'risk_score': 32, + 'kyc_status': 'current', + 'last_review_date': '2023-12-15', + 'account_opened': '2021-06-20', + 'account_balance': 45890.25, + }, + 'ACC-005': { + 'id': 'ACC-005', + 'name': 'Sunrise Consulting', + 'type': 'business', + 'status': 'active', + 'registration_number': 'LLC-2019-41205', + 'jurisdiction': 'Nevada', + 'incorporation_date': '2019-01-08', + 'registered_agent': 'Silver State Filings', + 'address': '3960 Howard Hughes Pkwy, Suite 500, Las Vegas, NV 89169', + 'phone': '+1-702-555-0138', + 'email': 'billing@sunriseconsulting.net', + 'industry': 'Management Consulting', + 'annual_revenue_band': '$1M-$5M', + 'employee_count': 8, + 'primary_contact': 'David Park, Managing Partner', + 'tax_id': 'XX-XXX9034', + 'risk_score': 51, + 'kyc_status': 'current', + 'last_review_date': '2023-10-30', + 'account_opened': '2019-02-14', + 'account_balance': 189340.00, + }, + 'ACC-006': { + 'id': 'ACC-006', + 'name': 'Digital Media Partners', + 'type': 'business', + 'status': 'active', + 'registration_number': 'LLC-2020-55672', + 'jurisdiction': 'California', + 'incorporation_date': '2020-05-14', + 'registered_agent': 'Pacific Registered Agents', + 'address': '2049 Century Park East, Suite 2500, Los Angeles, CA 90067', + 'phone': '+1-310-555-0216', + 'email': 'ap@digitalmediapartners.co', + 'industry': 'Digital Marketing', + 'annual_revenue_band': '$1M-$5M', + 'employee_count': 15, + 'primary_contact': 'Sarah Kim, Head of Finance', + 'tax_id': 'XX-XXX2288', + 'risk_score': 41, + 'kyc_status': 'current', + 'last_review_date': '2023-08-22', + 'account_opened': '2020-06-01', + 'account_balance': 112450.75, + }, + 'ACC-007': { + 'id': 'ACC-007', + 'name': 'Northern Logistics', + 'type': 'business', + 'status': 'active', + 'registration_number': 'LLC-2016-19483', + 'jurisdiction': 'Illinois', + 'incorporation_date': '2016-09-30', + 'registered_agent': 'Midwest Corporate Services', + 'address': '233 S Wacker Drive, Suite 8400, Chicago, IL 60606', + 'phone': '+1-312-555-0187', + 'email': 'invoices@northernlogistics.com', + 'industry': 'Freight & Logistics', + 'annual_revenue_band': '$5M-$10M', + 'employee_count': 42, + 'primary_contact': 'Michael Torres, VP Finance', + 'tax_id': 'XX-XXX6145', + 'risk_score': 29, + 'kyc_status': 'current', + 'last_review_date': '2023-07-14', + 'account_opened': '2016-10-15', + 'account_balance': 423100.00, + }, + 'ACC-008': { + 'id': 'ACC-008', + 'name': 'Quick Cash Services', + 'type': 'business', + 'status': 'active', + 'registration_number': 'MSB-2019-08841', + 'jurisdiction': 'Florida', + 'incorporation_date': '2019-11-05', + 'registered_agent': 'Sunshine State Filings', + 'address': '100 SE 2nd Street, Suite 3200, Miami, FL 33131', + 'phone': '+1-786-555-0144', + 'email': 'compliance@quickcashsvc.com', + 'industry': 'Money Services Business', + 'annual_revenue_band': '$1M-$5M', + 'employee_count': 11, + 'primary_contact': 'Roberto Diaz, Compliance Officer', + 'tax_id': 'XX-XXX7790', + 'risk_score': 67, + 'kyc_status': 'current', + 'last_review_date': '2024-01-02', + 'account_opened': '2019-12-01', + 'account_balance': 78920.50, + }, + 'ACC-009': { + 'id': 'ACC-009', + 'name': 'Golden Gate Holdings', + 'type': 'business', + 'status': 'active', + 'registration_number': 'LLC-2019-45892', + 'jurisdiction': 'Delaware', + 'incorporation_date': '2019-03-15', + 'registered_agent': 'National Corporate Services', + 'address': '1455 Market Street, Suite 600, San Francisco, CA 94103', + 'phone': '+1-415-555-0189', + 'email': 'ir@goldengateholdings.com', + 'industry': 'Investment Management', + 'annual_revenue_band': '$1M-$5M', + 'employee_count': 12, + 'primary_contact': 'James Chen, Managing Director', + 'tax_id': 'XX-XXX4521', + 'risk_score': 62, + 'kyc_status': 'current', + 'last_review_date': '2023-11-20', + 'account_opened': '2019-04-01', + 'account_balance': 1245000.00, + }, + 'ACC-010': { + 'id': 'ACC-010', + 'name': 'Harbor Real Estate', + 'type': 'business', + 'status': 'active', + 'registration_number': 'LLC-2020-62341', + 'jurisdiction': 'California', + 'incorporation_date': '2020-02-28', + 'registered_agent': 'Pacific Registered Agents', + 'address': '101 California Street, Suite 3800, San Francisco, CA 94111', + 'phone': '+1-415-555-0231', + 'email': 'closings@harborrealestate.com', + 'industry': 'Real Estate', + 'annual_revenue_band': '$5M-$10M', + 'employee_count': 22, + 'primary_contact': 'Linda Wu, Broker', + 'tax_id': 'XX-XXX3392', + 'risk_score': 25, + 'kyc_status': 'current', + 'last_review_date': '2023-06-10', + 'account_opened': '2020-03-15', + 'account_balance': 892000.00, + }, + 'ACC-011': { + 'id': 'ACC-011', + 'name': 'Cafe Milano', + 'type': 'business', + 'status': 'active', + 'registration_number': 'DBA-2021-08823', + 'jurisdiction': 'California', + 'incorporation_date': '2021-04-12', + 'registered_agent': 'Self', + 'address': '2154 Union Street, San Francisco, CA 94123', + 'phone': '+1-415-555-0198', + 'email': 'owner@cafemilano-sf.com', + 'industry': 'Food & Beverage', + 'annual_revenue_band': '$500K-$1M', + 'employee_count': 9, + 'primary_contact': 'Giuseppe Rossi, Owner', + 'tax_id': 'XX-XXX1147', + 'risk_score': 15, + 'kyc_status': 'current', + 'last_review_date': '2023-05-20', + 'account_opened': '2021-05-01', + 'account_balance': 34500.00, + }, + 'ACC-012': { + 'id': 'ACC-012', + 'name': 'Island Getaway Travel', + 'type': 'business', + 'status': 'active', + 'registration_number': 'LLC-2020-71456', + 'jurisdiction': 'Florida', + 'incorporation_date': '2020-08-19', + 'registered_agent': 'Sunshine State Filings', + 'address': '800 Brickell Ave, Suite 1100, Miami, FL 33131', + 'phone': '+1-305-555-0156', + 'email': 'bookings@islandgetaway.travel', + 'industry': 'Travel & Tourism', + 'annual_revenue_band': '$1M-$5M', + 'employee_count': 14, + 'primary_contact': 'Carmen Reyes, General Manager', + 'tax_id': 'XX-XXX8834', + 'risk_score': 20, + 'kyc_status': 'current', + 'last_review_date': '2023-09-05', + 'account_opened': '2020-09-01', + 'account_balance': 67800.00, + }, + 'ACC-013': { + 'id': 'ACC-013', + 'name': 'City Newsstand', + 'type': 'business', + 'status': 'active', + 'registration_number': 'DBA-2015-04412', + 'jurisdiction': 'New York', + 'incorporation_date': '2015-06-01', + 'registered_agent': 'Self', + 'address': '350 Fifth Avenue, Lobby Level, New York, NY 10118', + 'phone': '+1-212-555-0133', + 'email': 'citynewsstand@aol.com', + 'industry': 'Retail', + 'annual_revenue_band': '$100K-$500K', + 'employee_count': 3, + 'primary_contact': 'Frank Abagnale Jr, Owner', + 'tax_id': 'XX-XXX2201', + 'risk_score': 10, + 'kyc_status': 'current', + 'last_review_date': '2022-12-01', + 'account_opened': '2015-07-01', + 'account_balance': 8900.00, + }, + 'ACC-014': { + 'id': 'ACC-014', + 'name': 'Metro Dry Cleaning', + 'type': 'business', + 'status': 'active', + 'registration_number': 'DBA-2018-12290', + 'jurisdiction': 'New York', + 'incorporation_date': '2018-03-22', + 'registered_agent': 'Self', + 'address': '891 Lexington Ave, New York, NY 10065', + 'phone': '+1-212-555-0177', + 'email': 'metrodrycleaning@gmail.com', + 'industry': 'Laundry Services', + 'annual_revenue_band': '$100K-$500K', + 'employee_count': 5, + 'primary_contact': 'Tony Park, Owner', + 'tax_id': 'XX-XXX6678', + 'risk_score': 12, + 'kyc_status': 'current', + 'last_review_date': '2023-02-15', + 'account_opened': '2018-04-01', + 'account_balance': 15200.00, + }, + 'ACC-015': { + 'id': 'ACC-015', + 'name': 'Pine Street Deli', + 'type': 'business', + 'status': 'active', + 'registration_number': 'DBA-2019-09917', + 'jurisdiction': 'Illinois', + 'incorporation_date': '2019-07-14', + 'registered_agent': 'Self', + 'address': '47 W Pine Street, Chicago, IL 60610', + 'phone': '+1-312-555-0129', + 'email': 'pinestreetdeli@outlook.com', + 'industry': 'Food & Beverage', + 'annual_revenue_band': '$100K-$500K', + 'employee_count': 4, + 'primary_contact': 'Sam Goldstein, Owner', + 'tax_id': 'XX-XXX3345', + 'risk_score': 8, + 'kyc_status': 'current', + 'last_review_date': '2023-04-10', + 'account_opened': '2019-08-01', + 'account_balance': 11300.00, + }, +} + +# ============================================================================= +# Exchange Rates +# ============================================================================= + +# Simplified: one rate per currency pair (mid-January 2024 approximations). +_FX_RATES: dict[tuple[str, str], float] = { + ('EUR', 'USD'): 1.0847, + ('GBP', 'USD'): 1.2693, + ('CHF', 'USD'): 1.1782, + ('USD', 'EUR'): 0.9219, + ('USD', 'GBP'): 0.7878, + ('USD', 'CHF'): 0.8488, +} + + +def _to_usd(amount: float, currency: str) -> float: + """Convert an amount to USD.""" + if currency == 'USD': + return amount + rate = _FX_RATES[(currency, 'USD')] + return round(amount * rate, 2) + + +# ============================================================================= +# Transaction Data +# ============================================================================= + +# Compact definitions: (txn_id, from, to, amount, currency, date, memo, type) +# Expanded into verbose records by _build_transaction(). +# Batch transactions use type='batch_wire' and have a batch_id. +_COMPACT_TXNS: list[tuple[str, str, str, float, str, str, str, str]] = [ + # --- Layer 0: ACC-001 outgoing --- + ( + 'TXN-001', + 'ACC-001', + 'ACC-002', + 43580.00, + 'EUR', + '2024-01-15', + 'Consulting services - Q1 retainer', + 'wire', + ), + ( + 'TXN-002', + 'ACC-001', + 'ACC-003', + 25050.00, + 'GBP', + '2024-01-16', + 'Shipping contract - Baltic route', + 'wire', + ), + ( + 'TXN-003', + 'ACC-001', + 'ACC-004', + 23150.00, + 'USD', + '2024-01-17', + 'Personal transfer - living expenses', + 'wire', + ), + # --- Layer 0: ACC-001 small outgoing (noise, all below threshold) --- + ( + 'TXN-034', + 'ACC-001', + 'ACC-011', + 1200.00, + 'USD', + '2024-01-15', + 'Catering deposit - private event', + 'ach', + ), + ( + 'TXN-035', + 'ACC-001', + 'ACC-013', + 850.00, + 'USD', + '2024-01-16', + 'Newspaper subscription - annual', + 'ach', + ), + ( + 'TXN-036', + 'ACC-001', + 'ACC-014', + 2100.00, + 'USD', + '2024-01-18', + 'Suit dry cleaning - monthly', + 'ach', + ), + ( + 'TXN-037', + 'ACC-001', + 'ACC-015', + 400.00, + 'USD', + '2024-01-19', + 'Lunch - client meeting', + 'ach', + ), + # --- Layer 1: ACC-002 outgoing --- + # TXN-004 is a BATCH wire -- real recipients hidden behind BATCH-001 + ( + 'TXN-004', + 'ACC-002', + 'BATCH', + 28400.00, + 'USD', + '2024-01-19', + 'Batch wire - multiple payees', + 'batch_wire', + ), + ( + 'TXN-005', + 'ACC-002', + 'ACC-006', + 14570.00, + 'EUR', + '2024-01-20', + 'Marketing campaign - Q1 digital', + 'wire', + ), + ( + 'TXN-006', + 'ACC-002', + 'ACC-013', + 2100.00, + 'USD', + '2024-01-19', + 'Office supplies and newspapers', + 'ach', + ), + ( + 'TXN-038', + 'ACC-002', + 'ACC-014', + 3200.00, + 'USD', + '2024-01-20', + 'Uniform cleaning - warehouse staff', + 'ach', + ), + # --- Layer 1: ACC-003 outgoing --- + ( + 'TXN-007', + 'ACC-003', + 'ACC-005', + 18500.00, + 'USD', + '2024-01-20', + 'Freight forwarding - container lot', + 'wire', + ), + ( + 'TXN-008', + 'ACC-003', + 'ACC-007', + 9825.00, + 'GBP', + '2024-01-21', + 'Warehouse rental - Q1 lease', + 'wire', + ), + ( + 'TXN-009', + 'ACC-003', + 'ACC-014', + 1950.00, + 'USD', + '2024-01-20', + 'Uniform cleaning service', + 'ach', + ), + ( + 'TXN-039', + 'ACC-003', + 'ACC-015', + 800.00, + 'USD', + '2024-01-21', + 'Staff lunch order', + 'ach', + ), + # --- Layer 1: ACC-004 outgoing --- + # TXN-010 is a BATCH wire -- real recipients hidden behind BATCH-002 + ( + 'TXN-010', + 'ACC-004', + 'BATCH', + 13600.00, + 'USD', + '2024-01-22', + 'Batch wire - media + investment', + 'batch_wire', + ), + ( + 'TXN-011', + 'ACC-004', + 'ACC-008', + 9750.00, + 'USD', + '2024-01-21', + 'Currency exchange - BRL to USD', + 'wire', + ), + ( + 'TXN-040', + 'ACC-004', + 'ACC-011', + 1200.00, + 'USD', + '2024-01-22', + 'Event catering deposit', + 'ach', + ), + ( + 'TXN-041', + 'ACC-004', + 'ACC-012', + 3400.00, + 'USD', + '2024-01-23', + 'Travel booking - Cancun package', + 'ach', + ), + # --- Layer 2: ACC-005 outgoing --- + ( + 'TXN-012', + 'ACC-005', + 'ACC-009', + 38495.00, + 'EUR', + '2024-01-24', + 'Investment deposit - Q1 2024 tranche', + 'wire', + ), + ( + 'TXN-013', + 'ACC-005', + 'ACC-010', + 6200.00, + 'USD', + '2024-01-25', + 'Property deposit - Unit 4B', + 'wire', + ), + ( + 'TXN-042', + 'ACC-005', + 'ACC-015', + 900.00, + 'USD', + '2024-01-24', + 'Staff lunch catering', + 'ach', + ), + # --- Layer 2: ACC-006 outgoing --- + ( + 'TXN-014', + 'ACC-006', + 'ACC-009', + 22300.00, + 'USD', + '2024-01-25', + 'Equity purchase - Series B', + 'wire', + ), + ( + 'TXN-015', + 'ACC-006', + 'ACC-011', + 2800.00, + 'EUR', + '2024-01-26', + 'Catering contract - annual retainer', + 'ach', + ), + # --- Layer 2: ACC-007 outgoing --- + ( + 'TXN-016', + 'ACC-007', + 'ACC-009', + 10250.00, + 'USD', + '2024-01-26', + 'Fleet financing - vehicle lease', + 'wire', + ), + ( + 'TXN-017', + 'ACC-007', + 'ACC-015', + 1800.00, + 'USD', + '2024-01-25', + 'Lunch catering - staff event', + 'ach', + ), + # --- Layer 2: ACC-008 outgoing --- + ( + 'TXN-018', + 'ACC-008', + 'ACC-009', + 7150.00, + 'CHF', + '2024-01-27', + 'Wire transfer - client remittance', + 'wire', + ), + ( + 'TXN-019', + 'ACC-008', + 'ACC-012', + 1500.00, + 'USD', + '2024-01-26', + 'Travel booking - Caribbean package', + 'ach', + ), + # --- Noise: incoming from external accounts --- + ( + 'TXN-020', + 'EXT-101', + 'ACC-002', + 85000.00, + 'USD', + '2024-01-10', + 'Import duties settlement - FY2023', + 'wire', + ), + ( + 'TXN-021', + 'EXT-102', + 'ACC-003', + 62400.00, + 'EUR', + '2024-01-08', + 'Charter party payment - MV Nordic Star', + 'wire', + ), + ( + 'TXN-022', + 'EXT-103', + 'ACC-005', + 34750.00, + 'USD', + '2024-01-12', + 'Strategy engagement - Phase 2 milestone', + 'wire', + ), + ( + 'TXN-023', + 'EXT-104', + 'ACC-006', + 19200.00, + 'USD', + '2024-01-11', + 'Ad campaign management - Nov/Dec', + 'ach', + ), + ( + 'TXN-024', + 'EXT-105', + 'ACC-007', + 27800.00, + 'USD', + '2024-01-09', + 'Warehousing contract - quarterly', + 'wire', + ), + ( + 'TXN-025', + 'EXT-106', + 'ACC-008', + 15600.00, + 'CHF', + '2024-01-13', + 'FX settlement - batch 2024-01', + 'wire', + ), + ( + 'TXN-026', + 'EXT-107', + 'ACC-009', + 120000.00, + 'USD', + '2024-01-05', + 'Capital call - Fund III', + 'wire', + ), + ( + 'TXN-027', + 'EXT-108', + 'ACC-004', + 41300.00, + 'USD', + '2024-01-14', + 'Trust distribution - Q4 2023', + 'wire', + ), + ( + 'TXN-028', + 'EXT-109', + 'ACC-001', + 155000.00, + 'EUR', + '2024-01-03', + 'Trade consortium dividend', + 'wire', + ), + ( + 'TXN-029', + 'EXT-110', + 'ACC-005', + 22000.00, + 'USD', + '2024-01-07', + 'Advisory retainer - Jan 2024', + 'ach', + ), + ( + 'TXN-030', + 'EXT-111', + 'ACC-009', + 55000.00, + 'USD', + '2024-01-06', + 'LP commitment - tranche 2', + 'wire', + ), + ( + 'TXN-031', + 'EXT-112', + 'ACC-009', + 38500.00, + 'USD', + '2024-01-11', + 'Co-investment - Project Evergreen', + 'wire', + ), + ( + 'TXN-032', + 'EXT-113', + 'ACC-006', + 11750.00, + 'USD', + '2024-01-09', + 'Creative services - website redesign', + 'ach', + ), + ( + 'TXN-033', + 'EXT-114', + 'ACC-002', + 29300.00, + 'GBP', + '2024-01-12', + 'Commodity futures settlement', + 'wire', + ), + # --- More noise from externals (to fill pages) --- + ( + 'TXN-043', + 'EXT-115', + 'ACC-001', + 8500.00, + 'USD', + '2024-01-04', + 'Energy consulting retainer', + 'wire', + ), + ( + 'TXN-044', + 'EXT-116', + 'ACC-001', + 12000.00, + 'EUR', + '2024-01-06', + 'Textile import commission', + 'wire', + ), + ( + 'TXN-045', + 'EXT-117', + 'ACC-001', + 3200.00, + 'EUR', + '2024-01-09', + 'Freight brokerage fee', + 'ach', + ), + ( + 'TXN-046', + 'EXT-118', + 'ACC-002', + 45000.00, + 'USD', + '2024-01-07', + 'Coral Bay export contract - Phase 1', + 'wire', + ), + ( + 'TXN-047', + 'EXT-119', + 'ACC-002', + 18700.00, + 'USD', + '2024-01-14', + 'Advisory fee - M&A due diligence', + 'wire', + ), + ( + 'TXN-048', + 'EXT-120', + 'ACC-003', + 22500.00, + 'EUR', + '2024-01-05', + 'Mineral transport contract', + 'wire', + ), + ( + 'TXN-049', + 'EXT-121', + 'ACC-003', + 9800.00, + 'GBP', + '2024-01-11', + 'Cargo handling services - Q4', + 'wire', + ), + ( + 'TXN-050', + 'EXT-122', + 'ACC-005', + 15600.00, + 'USD', + '2024-01-09', + 'Strategy workshop facilitation', + 'wire', + ), + ( + 'TXN-051', + 'EXT-123', + 'ACC-005', + 7200.00, + 'USD', + '2024-01-13', + 'Holding company admin fee', + 'ach', + ), + ( + 'TXN-052', + 'EXT-115', + 'ACC-009', + 28000.00, + 'USD', + '2024-01-08', + 'Energy sector investment', + 'wire', + ), + ( + 'TXN-053', + 'EXT-119', + 'ACC-009', + 19500.00, + 'USD', + '2024-01-13', + 'Advisory placement fee', + 'wire', + ), +] + +# ============================================================================= +# Batch Data +# ============================================================================= + +# Batch transactions hide real recipients -- the parent transaction only shows +# a total amount. You must call get_batch_details() to see where the money went. +_BATCHES: dict[str, dict[str, Any]] = { + 'BATCH-001': { + 'batch_id': 'BATCH-001', + 'parent_transaction': 'TXN-004', + 'from_account': 'ACC-002', + 'from_account_name': 'Oceanic Trading LLC', + 'total_amount': 28400.00, + 'currency': 'USD', + 'date': '2024-01-19', + 'status': 'completed', + 'sub_transactions': [ + { + 'id': 'TXN-004-A', + 'to_account': 'ACC-005', + 'to_account_name': 'Sunrise Consulting', + 'amount': 22400.00, + 'currency': 'USD', + 'memo': 'Subcontractor payment - Project Alpha', + 'status': 'completed', + }, + { + 'id': 'TXN-004-B', + 'to_account': 'ACC-010', + 'to_account_name': 'Harbor Real Estate', + 'amount': 6000.00, + 'currency': 'USD', + 'memo': 'Property escrow deposit - Lot 7', + 'status': 'completed', + }, + ], + }, + 'BATCH-002': { + 'batch_id': 'BATCH-002', + 'parent_transaction': 'TXN-010', + 'from_account': 'ACC-004', + 'from_account_name': 'Maria Santos', + 'total_amount': 13600.00, + 'currency': 'USD', + 'date': '2024-01-22', + 'status': 'completed', + 'sub_transactions': [ + { + 'id': 'TXN-010-A', + 'to_account': 'ACC-006', + 'to_account_name': 'Digital Media Partners', + 'amount': 8600.00, + 'currency': 'USD', + 'memo': 'Media production - promotional video', + 'status': 'completed', + }, + { + 'id': 'TXN-010-B', + 'to_account': 'ACC-009', + 'to_account_name': 'Golden Gate Holdings', + 'amount': 5000.00, + 'currency': 'USD', + 'memo': 'Investment contribution - Series A', + 'status': 'completed', + }, + ], + }, +} + +_TXN_TO_BATCH: dict[str, str] = { + 'TXN-004': 'BATCH-001', + 'TXN-010': 'BATCH-002', +} + +_SWIFT_CODES = ('CHASUS33', 'BOFAUS3N', 'CITIUS33', 'WFBIUS6S') +_OFFICER_IDS = ('OFC-0892', 'OFC-1247', 'OFC-0651', 'OFC-0433') + + +def _build_transaction( + idx: int, + txn_id: str, + from_acc: str, + to_acc: str, + amount: float, + currency: str, + date: str, + memo: str, + txn_type: str, +) -> dict[str, Any]: + """Expand compact transaction tuple into a verbose API-style record.""" + day = int(date.split('-')[2]) + settlement_day = min(day + 2, 28) + prefix = { + 'wire': 'WR', + 'ach': 'ACH', + 'check': 'CHK', + 'internal': 'INT', + 'batch_wire': 'BW', + } + + record: dict[str, Any] = { + 'id': txn_id, + 'reference': f'{prefix[txn_type]}-2024-{1000 + idx:04d}-{from_acc}-{to_acc}', + 'from_account': from_acc, + 'from_account_name': _ALL_NAMES.get(from_acc, from_acc), + 'amount': amount, + 'currency': currency, + 'date': date, + 'settlement_date': f'{date[:8]}{settlement_day:02d}', + 'type': txn_type, + 'status': 'completed', + 'intermediary_bank': f'SWIFT: {_SWIFT_CODES[idx % len(_SWIFT_CODES)]}', + 'memo': memo, + 'compliance_flags': [], + 'processing_fee': round(amount * 0.0015, 2), + 'batch_id': f'BT-{date.replace("-", "")}-{(idx % 5) + 1}', + 'officer_id': _OFFICER_IDS[idx % len(_OFFICER_IDS)], + } + + if txn_type == 'batch_wire': + # Batch wires show "Multiple payees" -- real recipients are in the batch. + batch_id = _TXN_TO_BATCH[txn_id] + record['to_account'] = 'MULTIPLE' + record['to_account_name'] = 'Multiple payees (see batch details)' + record['batch_id'] = batch_id + else: + record['to_account'] = to_acc + record['to_account_name'] = _ALL_NAMES.get(to_acc, to_acc) + + return record + + +_TRANSACTIONS: dict[str, dict[str, Any]] = { + t[0]: _build_transaction(i, *t) for i, t in enumerate(_COMPACT_TXNS) +} + +_flagged_accounts: list[dict[str, Any]] = [] + + +# ============================================================================= +# Ground Truth (computed from the transaction data) +# ============================================================================= + + +def _compute_ground_truth() -> list[tuple[str, str, int, float]]: # noqa: C901 + """BFS from ACC-001 up to 3 hops, return convergence points. + + Only follows outgoing transactions >= AMOUNT_THRESHOLD (in USD). + Expands batch transactions to find real recipients. + Returns (account_id, name, source_count, total_inflow_usd) sorted by inflow desc. + """ + inflow: dict[str, float] = {} + sources: dict[str, set[str]] = {} + + def _record_edge( + from_acct: str, to_acct: str, amount: float, currency: str + ) -> list[str]: + """Record an edge if above threshold. Return list of new destinations.""" + usd = _to_usd(amount, currency) + if usd < AMOUNT_THRESHOLD: + return [] + inflow[to_acct] = inflow.get(to_acct, 0) + usd + if to_acct not in sources: + sources[to_acct] = set() + sources[to_acct].add(from_acct) + return [to_acct] + + visited: set[str] = set() + current_layer = ['ACC-001'] + + for _ in range(3): + next_layer: list[str] = [] + for acct in current_layer: + if acct in visited: + continue + visited.add(acct) + for txn in _TRANSACTIONS.values(): + if txn['from_account'] != acct: + continue + if txn['type'] == 'batch_wire': + # Expand batch to find real recipients + batch_id = txn['batch_id'] + batch = _BATCHES[batch_id] + for sub in batch['sub_transactions']: + dests = _record_edge( + acct, sub['to_account'], sub['amount'], sub['currency'] + ) + for d in dests: + if d not in visited: + next_layer.append(d) + else: + dests = _record_edge( + acct, txn['to_account'], txn['amount'], txn['currency'] + ) + for d in dests: + if d not in visited: + next_layer.append(d) + current_layer = next_layer + + convergence: list[tuple[str, str, int, float]] = [] + for acct_id, srcs in sources.items(): + if len(srcs) >= 2: + name = _ACCOUNT_NAMES.get(acct_id, acct_id) + convergence.append((acct_id, name, len(srcs), round(inflow[acct_id], 2))) + + convergence.sort(key=lambda x: (-x[3], x[0])) + return convergence + + +EXPECTED_CONVERGENCE = _compute_ground_truth() +EXPECTED_FLAG_IDS = {c[0] for c in EXPECTED_CONVERGENCE} + + +# ============================================================================= +# Mock Tools +# ============================================================================= + + +def list_account_transactions(account_id: str, page: int = 1) -> dict[str, Any]: + """List transactions for an account, paginated (3 per page). + + Returns transaction summaries -- use get_transaction() for full details + including amounts and currencies. + + Args: + account_id: The account ID to look up (e.g. "ACC-001"). + page: Page number (1-indexed). Each page returns up to 3 transactions. + + Returns: + A dict with: + - transactions: list of transaction summaries (id, direction, date, counterparty, category, status) + - page: current page number + - total_pages: total number of pages + - has_more: whether there are more pages + """ + all_results: list[dict[str, Any]] = [] + for txn_id, txn in _TRANSACTIONS.items(): + if txn['from_account'] == account_id: + counterparty = txn['to_account_name'] + if txn['type'] == 'batch_wire': + counterparty = 'Multiple payees (batch wire)' + all_results.append( + { + 'transaction_id': txn_id, + 'direction': 'outgoing', + 'date': txn['date'], + 'counterparty': counterparty, + 'category': f'{txn["type"]}_transfer', + 'status': txn['status'], + } + ) + elif txn['to_account'] == account_id: + all_results.append( + { + 'transaction_id': txn_id, + 'direction': 'incoming', + 'date': txn['date'], + 'counterparty': txn['from_account_name'], + 'category': f'{txn["type"]}_transfer', + 'status': txn['status'], + } + ) + all_results.sort(key=lambda x: x['date']) + + total = len(all_results) + total_pages = max(1, (total + PAGE_SIZE - 1) // PAGE_SIZE) + start = (page - 1) * PAGE_SIZE + end = start + PAGE_SIZE + page_results = all_results[start:end] + + return { + 'transactions': page_results, + 'page': page, + 'total_pages': total_pages, + 'has_more': page < total_pages, + } + + +def get_transaction(transaction_id: str) -> dict[str, Any]: + """Get full details of a specific transaction. + + Note: batch wire transfers show "Multiple payees" as the recipient. + Use get_batch_details() with the batch_id to see individual recipients. + + Args: + transaction_id: The transaction ID (e.g. "TXN-001"). + + Returns: + Full transaction record including amount, currency, accounts, dates, and metadata. + Batch wires include a batch_id field -- call get_batch_details() to expand. + """ + txn = _TRANSACTIONS.get(transaction_id) + if txn is None: + return {'error': f'Transaction {transaction_id} not found'} + return txn + + +def get_account_info(account_id: str) -> dict[str, Any]: + """Get account details. + + Args: + account_id: The account ID to look up (e.g. "ACC-001"). + + Returns: + Full account record including name, type, address, risk score, and more. + """ + account = _ACCOUNTS.get(account_id) + if account is None: + return {'error': f'Account {account_id} not found'} + return account + + +def get_exchange_rate(from_currency: str, to_currency: str) -> dict[str, Any]: + """Get the exchange rate between two currencies. + + Args: + from_currency: Source currency code (e.g. "EUR"). + to_currency: Target currency code (e.g. "USD"). + + Returns: + A dict with from_currency, to_currency, and rate. + """ + if from_currency == to_currency: + return {'from_currency': from_currency, 'to_currency': to_currency, 'rate': 1.0} + rate = _FX_RATES.get((from_currency, to_currency)) + if rate is None: + return {'error': f'No rate available for {from_currency}/{to_currency}'} + return {'from_currency': from_currency, 'to_currency': to_currency, 'rate': rate} + + +def get_batch_details(batch_id: str) -> dict[str, Any]: + """Get the sub-transactions within a batch wire transfer. + + Batch wires bundle multiple payments into a single transaction. This endpoint + returns the individual recipients and amounts. + + Args: + batch_id: The batch ID (e.g. "BATCH-001"). + + Returns: + Batch details including sub_transactions with individual recipients and amounts. + """ + batch = _BATCHES.get(batch_id) + if batch is None: + return {'error': f'Batch {batch_id} not found'} + return batch + + +def flag_account(account_id: str, reason: str) -> dict[str, Any]: + """Flag an account for suspicious activity. + + Args: + account_id: The account ID to flag. + reason: The reason for flagging. + + Returns: + Confirmation with the account ID and reason. + """ + record = {'account_id': account_id, 'reason': reason, 'status': 'flagged'} + _flagged_accounts.append(record) + return record + + +def create_toolset() -> FunctionToolset[None]: + """Create the transaction investigation toolset.""" + toolset: FunctionToolset[None] = FunctionToolset() + toolset.add_function(list_account_transactions) + toolset.add_function(get_transaction) + toolset.add_function(get_account_info) + toolset.add_function(get_exchange_rate) + toolset.add_function(get_batch_details) + toolset.add_function(flag_account) + return toolset + + +# ============================================================================= +# Agent Factories +# ============================================================================= + +SYSTEM_PROMPT = ( + 'You are a financial crime investigator. Use the available tools to trace ' + 'money flows and identify suspicious patterns in transaction networks. ' + 'Transaction lists are paginated -- make sure to fetch ALL pages. ' + 'Some transactions are batch wires -- you MUST expand them with get_batch_details ' + 'to see where the money actually went. ' + 'Foreign currency amounts must be converted to USD using get_exchange_rate ' + 'before comparing against thresholds or summing totals.' +) + + +def create_tool_calling_agent(toolset: FunctionToolset[None]) -> Agent[None, str]: + """Create agent with standard tool calling.""" + return Agent(MODEL, toolsets=[toolset], system_prompt=SYSTEM_PROMPT) + + +def create_code_execution_agent(toolset: FunctionToolset[None]) -> Agent[None, str]: + """Create agent with code execution (tools as Python functions).""" + environment = MontyEnvironment() + code_toolset: CodeExecutionToolset[None] = CodeExecutionToolset( + environment, + toolset=toolset, + max_retries=MAX_RETRIES, + ) + return Agent(MODEL, toolsets=[code_toolset], system_prompt=SYSTEM_PROMPT) + + +# ============================================================================= +# Metrics Collection +# ============================================================================= + + +@dataclass +class RunMetrics: + """Metrics collected from an agent run.""" + + mode: str + request_count: int + input_tokens: int + output_tokens: int + retry_count: int + output: str + + @property + def total_tokens(self) -> int: + return self.input_tokens + self.output_tokens + + +def extract_metrics(result: AgentRunResult[str], mode: str) -> RunMetrics: + """Extract metrics from agent result.""" + request_count = 0 + input_tokens = 0 + output_tokens = 0 + retry_count = 0 + + for msg in result.all_messages(): + if isinstance(msg, ModelResponse): + request_count += 1 + if msg.usage: + input_tokens += msg.usage.input_tokens or 0 + output_tokens += msg.usage.output_tokens or 0 + for part in getattr(msg, 'parts', []): + if isinstance(part, RetryPromptPart): + retry_count += 1 + + return RunMetrics( + mode=mode, + request_count=request_count, + input_tokens=input_tokens, + output_tokens=output_tokens, + retry_count=retry_count, + output=result.output, + ) + + +# ============================================================================= +# Run Functions +# ============================================================================= + + +async def run_tool_calling(toolset: FunctionToolset[None]) -> RunMetrics: + """Run with standard tool calling.""" + global _flagged_accounts + _flagged_accounts = [] + + with logfire.span('tool_calling'): + agent = create_tool_calling_agent(toolset) + result = await agent.run(PROMPT) + return extract_metrics(result, 'tool_calling') + + +async def run_code_execution(toolset: FunctionToolset[None]) -> RunMetrics: + """Run with code execution tool calling.""" + global _flagged_accounts + _flagged_accounts = [] + + with logfire.span('code_execution_tool_calling'): + agent = create_code_execution_agent(toolset) + code_toolset = agent.toolsets[0] + async with code_toolset: + result = await agent.run(PROMPT) + return extract_metrics(result, 'code_execution') + + +# ============================================================================= +# Main Demo +# ============================================================================= + + +def verify_results(mode: str) -> None: + """Check flagged accounts against ground truth and log to logfire.""" + flagged_ids = {r['account_id'] for r in _flagged_accounts} + correct: set[str] = set() + missed: set[str] = set() + spurious: set[str] = set() + for fid in flagged_ids: + if fid in EXPECTED_FLAG_IDS: + correct.add(fid) + else: + spurious.add(fid) + for eid in EXPECTED_FLAG_IDS: + if eid not in flagged_ids: + missed.add(eid) + + if flagged_ids == EXPECTED_FLAG_IDS: + logfire.info( + '{mode} verification: PASS -- correctly flagged {count} accounts: {ids}', + mode=mode, + count=len(correct), + ids=', '.join(sorted(correct)), + ) + else: + logfire.error( + '{mode} verification: FAIL -- correct: {correct}, missed: {missed}, spurious: {spurious}', + mode=mode, + correct=', '.join(sorted(correct)) or 'none', + missed=', '.join(sorted(missed)) or 'none', + spurious=', '.join(sorted(spurious)) or 'none', + ) + + +def log_metrics(metrics: RunMetrics) -> None: + """Log metrics to logfire.""" + logfire.info( + '{mode} completed: {requests} requests, {tokens} tokens', + mode=metrics.mode, + requests=metrics.request_count, + tokens=metrics.total_tokens, + input_tokens=metrics.input_tokens, + output_tokens=metrics.output_tokens, + retries=metrics.retry_count, + ) + + +async def main() -> None: + logfire.configure(service_name='code-execution-follow-the-money') + logfire.instrument_pydantic_ai() + + toolset = create_toolset() + + # Print ground truth for debugging + print('Expected convergence points:') + for acct_id, name, src_count, total in EXPECTED_CONVERGENCE: + print(f' {acct_id} ({name}): {src_count} sources, ${total:,.2f} USD inflow') + print() + + # with logfire.span('demo_tool_calling'): + # trad = await run_tool_calling(toolset) + # log_metrics(trad) + # verify_results('tool_calling') + + with logfire.span('demo_code_execution'): + code = await run_code_execution(toolset) + log_metrics(code) + verify_results('code_execution') + + print('View traces: https://logfire.pydantic.dev') + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/examples/code_execution/github_pr_analysis.py b/examples/code_execution/github_pr_analysis.py new file mode 100644 index 0000000..eb45dfb --- /dev/null +++ b/examples/code_execution/github_pr_analysis.py @@ -0,0 +1,362 @@ +"""Code Execution Example: GitHub PR Velocity Analysis (Real API). + +Analyzes the last 30 merged pull requests in a real GitHub repository using +the REST API -- no mock data. + +For each PR the agent fetches reviews and changed files, creating a dependent +fan-out of ~60 API calls that traditional tool calling must spread across +many LLM roundtrips while code execution handles in a single loop. + + Level 1: List closed PRs, filter merged -> 1-2 API calls (paginated) + Level 2: Per PR, fetch reviews + files -> 60 API calls (30 x 2) + ~= 62 total + +Traditional parallel tool calling (best case): + Roundtrip 1: list PRs -> 30 PR objects enter context + Roundtrip 2: 30x get_pr_reviews in parallel -> 30 review payloads enter context + Roundtrip 3: 30x get_pr_files in parallel -> 30 file-list payloads enter context + Roundtrip 4: model mentally aggregates ~60 JSON blobs into a report + = 4 roundtrips, ~60 JSON payloads in context (~50-100k tokens of raw API data) + +Code execution: + Roundtrip 1: model writes a loop that paginates, fans out, computes inline + Roundtrip 2: model formats the ~300 token summary + = 2 roundtrips, only the final summary enters context + +Requires: + GITHUB_PERSONAL_ACCESS_TOKEN environment variable. + +Run: + uv run -m examples.code_execution.github_pr_analysis +""" + +from __future__ import annotations + +import asyncio +import os +from dataclasses import dataclass +from datetime import datetime +from typing import Any + +import httpx +import logfire + +from pydantic_ai import Agent +from pydantic_harness.environments.monty import MontyEnvironment +from pydantic_ai.messages import ModelResponse, RetryPromptPart +from pydantic_ai.run import AgentRunResult +from pydantic_ai.toolsets import FunctionToolset +from pydantic_harness.toolsets.code_execution import CodeExecutionToolset + +# ============================================================================= +# Configuration +# ============================================================================= + +REPO = os.environ.get('GITHUB_REPO', 'pydantic/pydantic-ai') +TARGET_PRS = 30 +MODEL = 'gateway/anthropic:claude-sonnet-4-5' +MAX_RETRIES = 5 + +PROMPT = f"""\ +Analyze the last {TARGET_PRS} merged pull requests in "{REPO}". + +Use list_pull_requests to fetch closed PRs (not all closed PRs are merged -- \ +filter for entries where merged_at is not null). Paginate if the first page \ +doesn't yield {TARGET_PRS} merged PRs. + +For each merged PR, fetch its reviews and changed files. Then produce a report: + +1. Average and median hours_to_merge across all analyzed PRs +2. The 5 slowest PRs to merge -- title, author, hours_to_merge +3. Top 5 reviewers by total number of reviews submitted +4. Top 10 most frequently modified file paths across all PRs +5. Average number of files changed per PR +6. PR size distribution: small (<5 files), medium (5-20), large (>20) +7. Review coverage: percentage of PRs that received at least one review\ +""" + +SYSTEM_PROMPT = ( + 'You are a software engineering analyst. Use the available tools to query ' + 'the GitHub API and produce data-driven reports. ' + 'Each PR object includes an hours_to_merge field pre-computed from timestamps.' +) + +# ============================================================================= +# GitHub API Client +# ============================================================================= + +_client: httpx.Client | None = None + + +def _github_client() -> httpx.Client: + global _client + if _client is None: + token = os.environ.get('GITHUB_PERSONAL_ACCESS_TOKEN') + if not token: + raise RuntimeError( + 'GITHUB_PERSONAL_ACCESS_TOKEN not set. ' + 'Create a token at https://github.com/settings/tokens' + ) + _client = httpx.Client( + base_url='https://api.github.com', + headers={ + 'Authorization': f'token {token}', + 'Accept': 'application/vnd.github+json', + 'X-GitHub-Api-Version': '2022-11-28', + }, + timeout=30.0, + ) + return _client + + +def _parse_ts(ts: str) -> datetime: + """Parse a GitHub API timestamp like '2025-01-15T14:22:00Z'.""" + return datetime.fromisoformat(ts.replace('Z', '+00:00')) + + +# ============================================================================= +# Tool Call Tracking +# ============================================================================= + +_tool_calls: list[str] = [] + + +# ============================================================================= +# Tool Functions +# ============================================================================= + + +def list_pull_requests( + repo: str, + state: str = 'closed', + sort: str = 'updated', + direction: str = 'desc', + per_page: int = 30, + page: int = 1, +) -> dict[str, Any]: + """List pull requests for a GitHub repository. + + Args: + repo: Repository in "owner/repo" format (e.g. "pydantic/pydantic-ai"). + state: Filter by state: "open", "closed", or "all". + sort: Sort by: "created", "updated", or "popularity". + direction: Sort direction: "asc" or "desc". + per_page: Number of results per page (max 100). + page: Page number (1-based). + + Returns: + Pull request objects. Merged PRs include an hours_to_merge field. + """ + _tool_calls.append('list_pull_requests') + resp = _github_client().get( + f'/repos/{repo}/pulls', + params={ + 'state': state, + 'sort': sort, + 'direction': direction, + 'per_page': per_page, + 'page': page, + }, + ) + resp.raise_for_status() + prs = resp.json() + for pr in prs: + if pr.get('merged_at') and pr.get('created_at'): + delta = _parse_ts(pr['merged_at']) - _parse_ts(pr['created_at']) + pr['hours_to_merge'] = round(delta.total_seconds() / 3600, 1) + return { + 'page': page, + 'per_page': per_page, + 'count': len(prs), + 'pull_requests': prs, + } + + +def get_pr_reviews(repo: str, pr_number: int) -> dict[str, Any]: + """Get all reviews for a pull request. + + Args: + repo: Repository in "owner/repo" format. + pr_number: The pull request number. + + Returns: + Review objects with reviewer login, state, and submitted_at timestamp. + """ + _tool_calls.append('get_pr_reviews') + resp = _github_client().get(f'/repos/{repo}/pulls/{pr_number}/reviews') + resp.raise_for_status() + return {'pr_number': pr_number, 'reviews': resp.json()} + + +def get_pr_files(repo: str, pr_number: int) -> dict[str, Any]: + """Get files changed in a pull request. + + Args: + repo: Repository in "owner/repo" format. + pr_number: The pull request number. + + Returns: + Changed file objects with filename, status, additions, and deletions. + """ + _tool_calls.append('get_pr_files') + resp = _github_client().get( + f'/repos/{repo}/pulls/{pr_number}/files', + params={'per_page': 100}, + ) + resp.raise_for_status() + files = resp.json() + # Drop patch diffs -- they can be enormous and aren't needed for the analysis. + for f in files: + f.pop('patch', None) + return {'pr_number': pr_number, 'files': files} + + +# ============================================================================= +# Toolset & Agent Factories +# ============================================================================= + + +def create_toolset() -> FunctionToolset[None]: + toolset: FunctionToolset[None] = FunctionToolset() + toolset.add_function(list_pull_requests) + toolset.add_function(get_pr_reviews) + toolset.add_function(get_pr_files) + return toolset + + +def create_tool_calling_agent(toolset: FunctionToolset[None]) -> Agent[None, str]: + return Agent(MODEL, toolsets=[toolset], system_prompt=SYSTEM_PROMPT) + + +def create_code_execution_agent(toolset: FunctionToolset[None]) -> Agent[None, str]: + code_toolset: CodeExecutionToolset[None] = CodeExecutionToolset( + MontyEnvironment(), + toolset=toolset, + max_retries=MAX_RETRIES, + ) + return Agent(MODEL, toolsets=[code_toolset], system_prompt=SYSTEM_PROMPT) + + +# ============================================================================= +# Metrics +# ============================================================================= + + +@dataclass +class RunMetrics: + mode: str + request_count: int + input_tokens: int + output_tokens: int + retry_count: int + tool_calls: int + output: str + + @property + def total_tokens(self) -> int: + return self.input_tokens + self.output_tokens + + +def extract_metrics( + result: AgentRunResult[str], mode: str, tool_calls: int +) -> RunMetrics: + request_count = 0 + input_tokens = 0 + output_tokens = 0 + retry_count = 0 + for msg in result.all_messages(): + if isinstance(msg, ModelResponse): + request_count += 1 + if msg.usage: + input_tokens += msg.usage.input_tokens or 0 + output_tokens += msg.usage.output_tokens or 0 + for part in getattr(msg, 'parts', []): + if isinstance(part, RetryPromptPart): + retry_count += 1 + return RunMetrics( + mode=mode, + request_count=request_count, + input_tokens=input_tokens, + output_tokens=output_tokens, + retry_count=retry_count, + tool_calls=tool_calls, + output=result.output, + ) + + +# ============================================================================= +# Run Functions +# ============================================================================= + + +async def run_tool_calling(toolset: FunctionToolset[None]) -> RunMetrics | None: + """Run with standard tool calling. Returns None if the context overflows.""" + _tool_calls.clear() + try: + with logfire.span('tool_calling'): + agent = create_tool_calling_agent(toolset) + result = await agent.run(PROMPT) + return extract_metrics(result, 'tool_calling', len(_tool_calls)) + except Exception as e: + error_str = str(e) + if 'too long' in error_str or 'too many tokens' in error_str.lower(): + logfire.error( + 'tool_calling failed: context window overflow after {tool_calls} API calls', + tool_calls=len(_tool_calls), + error=error_str, + ) + return None + raise + + +async def run_code_execution(toolset: FunctionToolset[None]) -> RunMetrics: + _tool_calls.clear() + with logfire.span('code_execution'): + agent = create_code_execution_agent(toolset) + code_toolset = agent.toolsets[0] + async with code_toolset: + result = await agent.run(PROMPT) + return extract_metrics(result, 'code_execution', len(_tool_calls)) + + +# ============================================================================= +# Output +# ============================================================================= + + +def log_metrics(metrics: RunMetrics) -> None: + logfire.info( + '{mode}: {requests} requests, {tokens} tokens, {tool_calls} API calls', + mode=metrics.mode, + requests=metrics.request_count, + tokens=metrics.total_tokens, + input_tokens=metrics.input_tokens, + output_tokens=metrics.output_tokens, + tool_calls=metrics.tool_calls, + retries=metrics.retry_count, + ) + + +# ============================================================================= +# Main +# ============================================================================= + + +async def main() -> None: + logfire.configure(service_name='code-execution-github-pr-analysis') + logfire.instrument_pydantic_ai() + + toolset = create_toolset() + + with logfire.span('demo_tool_calling'): + trad = await run_tool_calling(toolset) + if trad: + log_metrics(trad) + + with logfire.span('demo_code_execution'): + code = await run_code_execution(toolset) + log_metrics(code) + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/examples/code_execution/pr_comment_buckets.py b/examples/code_execution/pr_comment_buckets.py new file mode 100644 index 0000000..678e3c0 --- /dev/null +++ b/examples/code_execution/pr_comment_buckets.py @@ -0,0 +1,288 @@ +"""Code Execution Example: Multi-Repo PR Deep Analysis via GitHub MCP. + +Demonstrates code execution's advantage over traditional tool calling -- +even when the model makes parallel tool calls. + +The task requires a dependent fan-out: + + Level 1 (N=3): List closed PRs for each repo -> 3 calls + Level 2 (N*M*Z = 3*5*4 = 60): Per PR, fetch files, -> 60 calls + reviews, review comments, issue comments + -------- + 63 total + +Traditional parallel tool calling (best case): + Roundtrip 1: 3 list-PRs calls fire in parallel -> 3 JSON results enter context + Roundtrip 2: Model sees PR numbers, fires 60 detail calls in parallel + -> 60 JSON results enter context (files, reviews, comments for + 15 PRs -- easily 100k+ tokens of intermediate data) + Roundtrip 3: Model reads ALL 63 results and tries to aggregate mentally + = 3 roundtrips, 63 tool results in context, ~100-200k tokens of raw JSON + +Code execution: + Roundtrip 1: Model writes async code with nested asyncio.gather, all 63 + API calls happen inside the sandbox, data is aggregated with + deterministic Python, only the final summary string is returned + Roundtrip 2: Model formats the summary as text + = 2 roundtrips, ~1k tokens of intermediate data + +The wins stack: + - Context: 63 JSON payloads in context (traditional) vs ~1k token summary (code exec) + - Cost: 100-200k tokens of I/O (traditional) vs ~5k tokens total (code exec) + - Accuracy: Deterministic code aggregation vs model "mental math" over 60 JSON blobs + - Latency: 3 serial roundtrips (traditional) vs 2 (code exec) + +Requires: + GITHUB_PERSONAL_ACCESS_TOKEN environment variable. + +Run: + uv run -m examples.code_execution.pr_comment_buckets +""" + +from __future__ import annotations + +import asyncio +import os +from dataclasses import dataclass + +import logfire + +from pydantic_ai import Agent +from pydantic_harness.environments.monty import MontyEnvironment +from pydantic_ai.mcp import MCPServerStreamableHTTP +from pydantic_ai.messages import ModelResponse, RetryPromptPart +from pydantic_ai.run import AgentRunResult +from pydantic_harness.toolsets.code_execution import CodeExecutionToolset + +# ============================================================================= +# Configuration +# ============================================================================= + +REPOS = ['pydantic/pydantic', 'pydantic/pydantic-ai', 'pydantic/logfire'] +PRS_PER_REPO = 5 + +# The prompt is designed to require cross-repo aggregation that's painful +# for a model to do "in its head" from 60+ JSON payloads, but trivial in code. +PROMPT = """\ +Analyze the {prs_per_repo} most recent closed PRs in EACH of these repositories: +{repos} + +For EACH PR across all repos, fetch ALL of the following: +1. Files changed (per_page=100, page=1) +2. Reviews (per_page=100, page=1) +3. Review comments -- the line-level code comments (per_page=100, page=1) +4. Issue comments -- the general discussion thread (per_page=100, page=1) + +That's {total_detail_calls} detail API calls across {total_prs} PRs. Use asyncio.gather +aggressively -- fan out repo fetches, then fan out all detail fetches for all PRs at once. + +From the collected data, compute and return ONLY these metrics (no raw data): + +PER-REPO BREAKDOWN: +- Repo name +- Avg files changed per PR +- Avg reviews per PR +- Avg review comments (line-level) per PR +- Avg issue comments (discussion) per PR +- Total engagement score: sum of (reviews + review_comments + issue_comments) across all PRs +- File categories: tests (path contains "test") / docs (path contains "doc" or ends .md) / source (everything else) +- Review verdicts: approved / changes_requested / commented / dismissed + +CROSS-REPO COMPARISON: +- Review engagement ratio per repo: total_engagement / total_files_changed +- Rank repos by engagement ratio (highest = most discussion per line of code changed) + +HOT FILES (appear in 2+ PRs across ANY repo): +- List file paths that were modified in multiple PRs, with the count + +HOTTEST PR: +- The single PR with the highest score = (reviews + review_comments + issue_comments + 1) * files_changed +- Include: repo, PR number, title, and the score breakdown + +TOP 5 MOST-DISCUSSED PRs: +- Ranked by (review_comments + issue_comments), with repo, PR number, title, and counts + +REVIEWER LEADERBOARD: +- Top 5 reviewers by total reviews given across all repos, with per-repo breakdown +""" + +MODEL = 'gateway/anthropic:claude-sonnet-4-5' +MAX_RETRIES = 5 + +SYSTEM_PROMPT = ( + 'You are a GitHub analyst. Use the available tools to fetch data and compute metrics. ' + 'Do ALL data fetching and aggregation inside run_code -- return only the final summary as text in your response, ' + 'not as code output. The point is to avoid polluting your context window with raw API data.' +) + +# ============================================================================= +# GitHub MCP +# ============================================================================= + + +def create_github_mcp() -> MCPServerStreamableHTTP: + """Create GitHub MCP server connection.""" + token = os.environ.get('GITHUB_PERSONAL_ACCESS_TOKEN') + if not token: + raise ValueError('GITHUB_PERSONAL_ACCESS_TOKEN environment variable is required') + + return MCPServerStreamableHTTP( + url='https://api.githubcopilot.com/mcp/', + timeout=30, + headers={ + 'Authorization': f'Bearer {token}', + 'X-MCP-Toolsets': 'issues,pull_requests', + 'X-MCP-Readonly': 'true', + }, + ) + + +# ============================================================================= +# Agent Factories +# ============================================================================= + + +def create_tool_calling_agent(github: MCPServerStreamableHTTP) -> Agent[None, str]: + """Create agent with standard parallel tool calling. + + Even with parallel calls, the model needs: + Roundtrip 1: list PRs per repo (3 calls) + Roundtrip 2: fetch details per PR (60 calls) -- results ALL enter context + Roundtrip 3: produce analysis from 63 JSON blobs in context + """ + return Agent(MODEL, toolsets=[github], system_prompt=SYSTEM_PROMPT) + + +def create_code_execution_agent(github: MCPServerStreamableHTTP) -> Agent[None, str]: + """Create agent with code execution. + + The model writes a single code block that: + - Fans out all 63 API calls with asyncio.gather + - Aggregates results in-memory with deterministic Python + - Returns only the summary string + Only ~1k tokens of tool I/O enter the context. + """ + code_toolset: CodeExecutionToolset[None] = CodeExecutionToolset( + MontyEnvironment(), + toolset=github, + max_retries=MAX_RETRIES, + ) + return Agent(MODEL, toolsets=[code_toolset], system_prompt=SYSTEM_PROMPT) + + +# ============================================================================= +# Metrics +# ============================================================================= + + +@dataclass +class RunMetrics: + mode: str + request_count: int + input_tokens: int + output_tokens: int + retry_count: int + output: str + + @property + def total_tokens(self) -> int: + return self.input_tokens + self.output_tokens + + +def extract_metrics(result: AgentRunResult[str], mode: str) -> RunMetrics: + request_count = 0 + input_tokens = 0 + output_tokens = 0 + retry_count = 0 + + for msg in result.all_messages(): + if isinstance(msg, ModelResponse): + request_count += 1 + if msg.usage: + input_tokens += msg.usage.input_tokens or 0 + output_tokens += msg.usage.output_tokens or 0 + for part in getattr(msg, 'parts', []): + if isinstance(part, RetryPromptPart): + retry_count += 1 + + return RunMetrics( + mode=mode, + request_count=request_count, + input_tokens=input_tokens, + output_tokens=output_tokens, + retry_count=retry_count, + output=result.output, + ) + + +def print_metrics(metrics: RunMetrics) -> None: + print(f'\n{"=" * 70}') + print(f' {metrics.mode.upper()}') + print(f'{"=" * 70}') + print(f' LLM roundtrips: {metrics.request_count}') + print(f' Input tokens: {metrics.input_tokens:,}') + print(f' Output tokens: {metrics.output_tokens:,}') + print(f' Total tokens: {metrics.total_tokens:,}') + print(f' Retries: {metrics.retry_count}') + print(f'{"=" * 70}') + print(f'\n{metrics.output}\n') + + +# ============================================================================= +# Run +# ============================================================================= + +TOTAL_PRS = len(REPOS) * PRS_PER_REPO +DETAIL_CALLS_PER_PR = 4 # files, reviews, review comments, issue comments +TOTAL_DETAIL_CALLS = TOTAL_PRS * DETAIL_CALLS_PER_PR +TOTAL_CALLS = len(REPOS) + TOTAL_DETAIL_CALLS + +FORMATTED_PROMPT = PROMPT.format( + prs_per_repo=PRS_PER_REPO, + repos='\n'.join(f'- {r}' for r in REPOS), + total_prs=TOTAL_PRS, + total_detail_calls=TOTAL_DETAIL_CALLS, +) + + +async def run_tool_calling(github: MCPServerStreamableHTTP) -> RunMetrics: + """Run with standard parallel tool calling.""" + with logfire.span('tool_calling'): + agent = create_tool_calling_agent(github) + result = await agent.run(FORMATTED_PROMPT) + return extract_metrics(result, 'traditional parallel tool calling') + + +async def run_code_execution(github: MCPServerStreamableHTTP) -> RunMetrics: + """Run with code execution.""" + with logfire.span('code_execution'): + agent = create_code_execution_agent(github) + code_toolset = agent.toolsets[0] + async with code_toolset: + result = await agent.run(FORMATTED_PROMPT) + return extract_metrics(result, 'code execution') + + +# ============================================================================= +# Main +# ============================================================================= + + +async def main() -> None: + logfire.configure(service_name='code-execution-pr-analysis') + logfire.instrument_pydantic_ai() + + github = create_github_mcp() + + # Traditional parallel tool calling first + async with github: + trad = await run_tool_calling(github) + print_metrics(trad) + + # Code execution (CodeExecutionToolset.__aenter__ enters the MCP server internally) + code = await run_code_execution(github) + print_metrics(code) + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/pydantic_harness/__init__.py b/pydantic_harness/__init__.py new file mode 100644 index 0000000..af4c833 --- /dev/null +++ b/pydantic_harness/__init__.py @@ -0,0 +1,58 @@ +"""pydantic-harness: execution environments and code mode capabilities for pydantic-ai.""" + +from pydantic_harness.capabilities.code_mode import CodeMode +from pydantic_harness.capabilities.execution_env import ExecutionEnv +from pydantic_harness.environments._base import ( + ExecutionEnvironment, + ExecutionProcess, + ExecutionResult, + FileInfo, +) +from pydantic_harness.toolsets.code_execution import ( + CodeExecutionToolset, + DescriptionFunc, + EnvironmentName, + FunctionSignature, + TypeSignature, + build_default_description, + get_environment, +) +from pydantic_harness.toolsets.code_execution._abstract import ( + CodeExecutionError, + CodeExecutionTimeout, + CodeRuntimeError, + CodeSyntaxError, + CodeTypingError, + FunctionCall, + FunctionCallback, +) +from pydantic_harness.toolsets.execution_environment import ExecutionEnvironmentToolset + +__all__ = ( + # Capabilities + 'CodeMode', + 'ExecutionEnv', + # Environments + 'ExecutionEnvironment', + 'ExecutionProcess', + 'ExecutionResult', + 'FileInfo', + # Toolsets + 'CodeExecutionToolset', + 'ExecutionEnvironmentToolset', + 'build_default_description', + 'get_environment', + # Type signatures + 'FunctionSignature', + 'TypeSignature', + 'DescriptionFunc', + 'EnvironmentName', + # Error types & callbacks + 'CodeExecutionError', + 'CodeExecutionTimeout', + 'CodeRuntimeError', + 'CodeSyntaxError', + 'CodeTypingError', + 'FunctionCall', + 'FunctionCallback', +) diff --git a/pydantic_harness/_python_signature.py b/pydantic_harness/_python_signature.py new file mode 100644 index 0000000..a437460 --- /dev/null +++ b/pydantic_harness/_python_signature.py @@ -0,0 +1,814 @@ +# TODO: When PR #4755 lands in pydantic-ai-slim, remove this file and +# switch imports to `from pydantic_ai._python_signature import ...` +"""Generate Python function signatures from functions and JSON schemas. + +This module provides utilities to represent tool definitions as human-readable +Python function signatures, which LLMs can understand more easily than raw +JSON schemas. Used by code mode to present tools as callable Python functions. +""" + +from __future__ import annotations + +import json +import re +import types +from collections.abc import Callable +from dataclasses import dataclass, field +from inspect import Parameter, Signature as InspectSignature, signature +from typing import Any, Literal, TypeAlias, Union, cast, get_origin + +from pydantic import BaseModel, TypeAdapter +from typing_extensions import get_type_hints + +from pydantic_ai._run_context import RunContext # private API +from pydantic_ai._utils import is_model_like # private API + + +def _get_schema_from_type(t: Any, *, mode: Literal['validation', 'serialization'] = 'validation') -> dict[str, Any]: + """Extract JSON schema from a `BaseModel`, dataclass, or `TypedDict`.""" + if isinstance(t, type) and issubclass(t, BaseModel): + return t.model_json_schema(mode=mode) + return TypeAdapter(t).json_schema(mode=mode) # pyright: ignore[reportUnknownArgumentType] + + +# ============================================================================= +# Type expression tree +# ============================================================================= + + +@dataclass +class GenericTypeExpr: + """A generic type expression like `list[User]`, `dict[str, User]`, `tuple[int, str]`.""" + + base: str + args: list[TypeExpr] + + def __str__(self) -> str: + return f'{self.base}[{", ".join(render_type_expr(a) for a in self.args)}]' + + +@dataclass +class UnionTypeExpr: + """A union type expression like `User | None`, `str | int`.""" + + members: list[TypeExpr] + + def __str__(self) -> str: + return ' | '.join(render_type_expr(m) for m in self.members) + + +# ============================================================================= +# Signature dataclasses +# ============================================================================= + + +def _render_docstring(text: str, indent: str = '') -> list[str]: + """Render a docstring as a list of indented lines.""" + text = text.strip() + if '\n' in text: + lines = [f'{indent}"""'] + for line in text.split('\n'): + lines.append(f'{indent}{line}' if line.strip() else '') + lines.append(f'{indent}"""') + return lines + return [f'{indent}"""{text}"""'] + + +@dataclass +class TypeFieldSignature: + name: str + type: TypeExpr + required: bool + description: str | None + + def __str__(self) -> str: + """Render this field as a line in a TypedDict class body.""" + type_str = render_type_expr(self.type) + if not self.required: + type_str = f'NotRequired[{type_str}]' + lines: list[str] = [f' {self.name}: {type_str}'] + if self.description: + lines.extend(_render_docstring(self.description, indent=' ')) + return '\n'.join(lines) + + +@dataclass +class TypeSignature: + name: str + + docstring: str | None = None + + fields: dict[str, TypeFieldSignature] = field(default_factory=dict[str, TypeFieldSignature]) + + def __str__(self) -> str: + """Render this type as a TypedDict class definition.""" + lines = [f'class {self.name}(TypedDict):'] + if self.docstring: + lines.extend(_render_docstring(self.docstring, indent=' ')) + if not self.fields: + if not self.docstring: + lines.append(' pass') + else: + for f in self.fields.values(): + lines.append(str(f)) + return '\n'.join(lines) + + def structurally_equal(self, other: TypeSignature) -> bool: + """Compare two TypeSignatures structurally, ignoring descriptions and docstrings.""" + if set(self.fields.keys()) != set(other.fields.keys()): + return False + for name, f in self.fields.items(): + other_f = other.fields[name] + if f.required != other_f.required: + return False + if render_type_expr(f.type) != render_type_expr(other_f.type): + return False + return True + + +TypeExpr: TypeAlias = 'TypeSignature | str | GenericTypeExpr | UnionTypeExpr' +"""A type expression that can reference `TypeSignature` objects, enabling automatic name propagation.""" + + +def render_type_expr(expr: TypeExpr) -> str: + """Render a type expression to a string.""" + if isinstance(expr, TypeSignature): + return expr.name + if isinstance(expr, str): + return expr + return str(expr) + + +@dataclass +class FunctionParam: + name: str + type: TypeExpr + default: str | None + + def __str__(self) -> str: + """Render this parameter as a function parameter string.""" + type_str = render_type_expr(self.type) + if self.default is not None: + return f'{self.name}: {type_str} = {self.default}' + return f'{self.name}: {type_str}' + + +@dataclass +class FunctionSignature: + """Python function signature with TypedDict definitions. + + This class holds all the data needed to render a function signature as Python code. + Use `str(sig)` for the default rendering, or call specific methods for variants. + """ + + name: str + """The function name.""" + + params: dict[str, FunctionParam] + """Function parameters.""" + + return_type: TypeExpr + """The return type expression.""" + + docstring: str | None = None + """Optional docstring for the function.""" + + referenced_types: list[TypeSignature] = field(default_factory=list[TypeSignature]) + """TypedDict class definitions needed by the signature.""" + + is_async: bool = True + """Whether to generate 'async def' (True) or 'def' (False).""" + + def __str__(self) -> str: + """Render with `...` body.""" + return self.render('...') + + def render(self, body: str) -> str: + """Render the signature with a specific body.""" + prefix = 'async def' if self.is_async else 'def' + params_str = ', '.join(str(p) for p in self.params.values()) + + return_str = render_type_expr(self.return_type) + if params_str: + # Force keyword-only params so LLMs always use named arguments + parts = [f'{prefix} {self.name}(*, {params_str}) -> {return_str}:'] + else: + parts = [f'{prefix} {self.name}() -> {return_str}:'] + + if self.docstring: + parts.extend(_render_docstring(self.docstring, indent=' ')) + + parts.append(f' {body}') + + return '\n'.join(parts) + + +# ============================================================================= +# Type annotation to TypeExpr conversion (Python annotations) +# ============================================================================= + + +def _get_type_name(t: Any) -> str: + """Get the name of a type.""" + if t is type(None): + return 'None' + if hasattr(t, '__name__'): + return t.__name__ + s = repr(t) + return s.replace('typing.', '').replace('typing_extensions.', '') + + +def _annotation_to_type_expr( + annotation: Any, + referenced_types: dict[str, TypeSignature], +) -> TypeExpr: + """Convert a Python type annotation to a TypeExpr.""" + if annotation is None or annotation is type(None): + return 'None' + + # Named types (BaseModel/TypedDict/dataclass) → look up in referenced_types + if is_model_like(annotation): + type_name = annotation.__name__ + if type_name in referenced_types: + return referenced_types[type_name] + return type_name + + # Handle Python 3.10+ union syntax (X | Y creates types.UnionType) + if isinstance(annotation, types.UnionType): + args = getattr(annotation, '__args__', ()) + members = [_annotation_to_type_expr(arg, referenced_types) for arg in args] + return UnionTypeExpr(members=members) + + origin = getattr(annotation, '__origin__', None) + args = getattr(annotation, '__args__', None) + + if origin is not None: + if args: + if origin is Union: + members = [_annotation_to_type_expr(arg, referenced_types) for arg in args] + return UnionTypeExpr(members=members) + base = _get_type_name(origin) + type_args = [_annotation_to_type_expr(arg, referenced_types) for arg in args] + return GenericTypeExpr(base=base, args=type_args) + return _get_type_name(origin) + + return _get_type_name(annotation) + + +# ============================================================================= +# Function signature builder (from Python functions) +# ============================================================================= + + +def _collect_referenced_types( + annotation: Any, + referenced_types: dict[str, TypeSignature], + tool_name: str, + path: str = '', + *, + mode: Literal['validation', 'serialization'] = 'validation', +) -> None: + """Recursively collect TypeSignature definitions from type annotations.""" + if annotation is None or annotation is type(None): + return + + if is_model_like(annotation): + type_name = annotation.__name__ + if type_name not in referenced_types: + schema = _get_schema_from_type(annotation, mode=mode) + schema_defs = schema.get('$defs', {}) + + # Process any $defs first + for def_name, def_schema in schema_defs.items(): + if ( + def_name not in referenced_types + and def_schema.get('type') == 'object' + and 'properties' in def_schema + ): + _build_and_register_type(def_name, def_schema, schema_defs, referenced_types, tool_name, path) + + # Then process the main schema + if schema.get('type') == 'object' and 'properties' in schema: + _build_and_register_type(type_name, schema, schema_defs, referenced_types, tool_name, path) + elif '$ref' in schema: + ref_name = schema['$ref'].split('/')[-1] + if ref_name in schema_defs and ref_name not in referenced_types: # pragma: no cover + ref_schema = schema_defs[ref_name] + if ref_schema.get('type') == 'object' and 'properties' in ref_schema: + referenced_types[ref_name] = _build_type_signature( + ref_name, ref_schema, schema_defs, referenced_types, tool_name, path + ) + return + + origin = get_origin(annotation) + args = getattr(annotation, '__args__', None) + if origin is not None and args: + for arg in args: + _collect_referenced_types(arg, referenced_types, tool_name, path, mode=mode) + + +def _build_function_params( + sig: InspectSignature, + type_hints: dict[str, Any], + referenced_types: dict[str, TypeSignature], + tool_name: str, +) -> dict[str, FunctionParam]: + """Build FunctionParam objects from a function's signature and type hints.""" + params: dict[str, FunctionParam] = {} + for i, (param_name, param) in enumerate(sig.parameters.items()): + annotation = type_hints.get(param_name) + + if i == 0 and annotation is not None and (annotation is RunContext or get_origin(annotation) is RunContext): + continue + + if annotation is not None: + _collect_referenced_types(annotation, referenced_types, tool_name, param_name) + + if annotation is not None: + type_expr = _annotation_to_type_expr(annotation, referenced_types) + else: + type_expr = 'Any' + + if param.default is Parameter.empty: + params[param_name] = FunctionParam(name=param_name, type=type_expr, default=None) + else: + default_str = repr(param.default) + params[param_name] = FunctionParam(name=param_name, type=type_expr, default=default_str) + return params + + +def function_to_signature( + func: Callable[..., Any], + name: str | None = None, + description: str | None = None, +) -> FunctionSignature: + """Build Signature from a Python function using inspect.""" + name = name or func.__name__ + sig = signature(func) + + try: + type_hints = get_type_hints(func, include_extras=True) + except (NameError, TypeError, AttributeError): + type_hints = {} + + referenced_types: dict[str, TypeSignature] = {} + params = _build_function_params(sig, type_hints, referenced_types, name) + + return_annotation = type_hints.get('return') + if return_annotation is not None: + _collect_referenced_types(return_annotation, referenced_types, name, 'Return', mode='serialization') + + return_type: TypeExpr = ( + _annotation_to_type_expr(return_annotation, referenced_types) if return_annotation else 'Any' + ) + + return FunctionSignature( + name=name, + params=params, + return_type=return_type, + docstring=description, + referenced_types=list(referenced_types.values()), + ) + + +# ============================================================================= +# JSON schema to signature conversion +# ============================================================================= + + +_JSON_TYPE_TO_PYTHON: dict[str, str] = { + 'string': 'str', + 'integer': 'int', + 'number': 'float', + 'boolean': 'bool', + 'null': 'None', + 'array': 'list', + 'object': 'dict', +} + + +def _json_type_to_python(json_type: str) -> str: + """Convert a JSON type string to Python type.""" + return _JSON_TYPE_TO_PYTHON.get(json_type, 'Any') + + +def _to_pascal_case(s: str) -> str: + """Convert a string to PascalCase.""" + s = re.sub(r'[^a-zA-Z0-9]', '_', s) + parts = s.split('_') + result = ''.join(part.capitalize() for part in parts if part) + if result and result[0].isdigit(): + result = '_' + result + return result + + +def _path_to_typename(tool_name: str, path: str) -> str: + """Convert a traversal path to a unique TypedDict name. + + Examples: + _path_to_typename('get_user', '') -> 'GetUser' + _path_to_typename('get_user', 'address') -> 'GetUserAddress' + _path_to_typename('get_user', 'home.address') -> 'GetUserHomeAddress' + """ + parts = [tool_name] + [p for p in path.split('.') if p] + return ''.join(_to_pascal_case(p) for p in parts) + + +def _process_schema_defs( + defs: dict[str, dict[str, Any]], + referenced_types: dict[str, TypeSignature], + tool_name: str, +) -> None: + """Process $defs from a JSON schema, populating referenced_types with TypeSignatures.""" + for def_name, def_schema in defs.items(): + if def_schema.get('type') == 'object' and 'properties' in def_schema: + if def_name not in referenced_types: + _build_and_register_type(def_name, def_schema, defs, referenced_types, tool_name, def_name) + + +def schema_to_signature( + name: str, + parameters_schema: dict[str, Any], + description: str | None = None, + return_type: str = 'Any', + return_schema: dict[str, Any] | None = None, +) -> FunctionSignature: + """Convert JSON schema to a FunctionSignature. + + Parameter and return schemas are processed independently — each resolves + $refs against its own $defs. Name collisions between parameter and return + types (e.g. both define a 'User' $def with different structures) are handled + by `dedup_referenced_types` at a later stage. + """ + # Process parameter schema with its own $defs + param_defs = parameters_schema.get('$defs', {}) + param_referenced: dict[str, TypeSignature] = {} + _process_schema_defs(param_defs, param_referenced, name) + params = _build_params_from_schema(parameters_schema, param_defs, param_referenced, name) + + # Process return schema independently (its own $defs) + resolved_return_type: TypeExpr = return_type + return_referenced: dict[str, TypeSignature] = {} + if return_schema is not None and return_type == 'Any': + return_defs = return_schema.get('$defs', {}) + _process_schema_defs(return_defs, return_referenced, name) + resolved_return_type = schema_to_type_expr(return_schema, return_defs, return_referenced, name, 'Return') + + # Handle case where return type couldn't be resolved + final_description = description + if return_schema is not None and resolved_return_type == 'Any': + return_schema_blob = json.dumps(return_schema, indent=2) + return_schema_note = f'\n\nReturn schema:\n{return_schema_blob}' + final_description = (description or '') + return_schema_note + final_description = final_description.strip() + + # Merge referenced types — dedup_referenced_types handles collisions later + all_referenced = list(param_referenced.values()) + list(return_referenced.values()) + + return FunctionSignature( + name=name, + params=params, + return_type=resolved_return_type, + docstring=final_description if final_description else None, + referenced_types=all_referenced, + ) + + +def _build_params_from_schema( + schema: dict[str, Any], + defs: dict[str, dict[str, Any]], + referenced_types: dict[str, TypeSignature], + tool_name: str, +) -> dict[str, FunctionParam]: + """Convert a JSON schema to a dict of FunctionParam objects.""" + properties = schema.get('properties', {}) + required = set(schema.get('required', [])) + + required_params: dict[str, FunctionParam] = {} + optional_params: dict[str, FunctionParam] = {} + + for prop_name, prop_schema in properties.items(): + type_expr = schema_to_type_expr(prop_schema, defs, referenced_types, tool_name, prop_name) + + if 'default' in prop_schema: + default_str = repr(prop_schema['default']) + optional_params[prop_name] = FunctionParam(name=prop_name, type=type_expr, default=default_str) + elif prop_name in required: + required_params[prop_name] = FunctionParam(name=prop_name, type=type_expr, default=None) + else: + # Optional without default — add | None + if _schema_allows_null(prop_schema): + optional_params[prop_name] = FunctionParam(name=prop_name, type=type_expr, default='None') + else: + nullable_expr = UnionTypeExpr(members=[type_expr, 'None']) + optional_params[prop_name] = FunctionParam(name=prop_name, type=nullable_expr, default='None') + + return {**required_params, **optional_params} + + +def _schema_allows_null(schema: dict[str, Any]) -> bool: + """Check if a schema already allows null values.""" + schema_type = schema.get('type') + if isinstance(schema_type, list) and 'null' in schema_type: + return True + if 'anyOf' in schema or 'oneOf' in schema: + union = schema.get('anyOf') or schema.get('oneOf', []) + return any(s.get('type') == 'null' for s in union) + return False + + +def schema_to_type_expr( + schema: dict[str, Any], + defs: dict[str, dict[str, Any]], + referenced_types: dict[str, TypeSignature], + tool_name: str, + path: str, +) -> TypeExpr: + """Convert a JSON schema to a TypeExpr.""" + # Handle $ref + if '$ref' in schema: + ref = schema['$ref'] + ref_name = ref.split('/')[-1] + # Ensure referenced def generates TypeSignature if needed + if ref_name in defs and ref_name not in referenced_types: + ref_schema = defs[ref_name] + if ref_schema.get('type') == 'object' and 'properties' in ref_schema: + _build_and_register_type(ref_name, ref_schema, defs, referenced_types, tool_name, path) + # Return the TypeSignature object if available, otherwise the name string + if ref_name in referenced_types: + return referenced_types[ref_name] + return ref_name + + # Handle anyOf/oneOf (union types) + if 'anyOf' in schema: + return _handle_union_schema(schema['anyOf'], defs, referenced_types, tool_name, path) + if 'oneOf' in schema: + return _handle_union_schema(schema['oneOf'], defs, referenced_types, tool_name, path) + + # Handle allOf + if 'allOf' in schema: + if len(schema['allOf']) == 1: + return schema_to_type_expr(schema['allOf'][0], defs, referenced_types, tool_name, path) + return 'Any' + + # Handle const + if 'const' in schema: + return f'Literal[{repr(schema["const"])}]' + + # Handle enum + if 'enum' in schema: + enum_values = ', '.join(repr(v) for v in schema['enum']) + return f'Literal[{enum_values}]' + + # Handle by type + schema_type = schema.get('type') + return _type_to_expr(schema_type, schema, defs, referenced_types, tool_name, path) + + +def _type_to_expr( + schema_type: str | list[str] | None, + schema: dict[str, Any], + defs: dict[str, dict[str, Any]], + referenced_types: dict[str, TypeSignature], + tool_name: str, + path: str, +) -> TypeExpr: + """Convert a schema type to a TypeExpr.""" + # Simple types — use shared mapping, skip compound types handled below + if isinstance(schema_type, str) and schema_type in _JSON_TYPE_TO_PYTHON and schema_type not in ('array', 'object'): + return _JSON_TYPE_TO_PYTHON[schema_type] + + # Array type + if schema_type == 'array': + items = schema.get('items', {}) + if items: + # Handle tuple schemas (items as list) + if isinstance(items, list): + items_list = cast(list[dict[str, Any]], items) + item_exprs = [ + schema_to_type_expr(item, defs, referenced_types, tool_name, f'{path}.{i}') + for i, item in enumerate(items_list) + ] + return GenericTypeExpr(base='tuple', args=item_exprs) + item_expr = schema_to_type_expr( + cast(dict[str, Any], items), defs, referenced_types, tool_name, f'{path}Item' + ) + return GenericTypeExpr(base='list', args=[item_expr]) + return 'list[Any]' + + # Object type + if schema_type == 'object': + if 'properties' in schema: + # Generate TypeSignature with path-based unique name + td_name = _path_to_typename(tool_name, path) + if td_name not in referenced_types: + referenced_types[td_name] = _build_type_signature( + td_name, schema, defs, referenced_types, tool_name, path + ) + return referenced_types[td_name] + if 'additionalProperties' in schema: + additional = schema['additionalProperties'] + if additional is True: + return 'dict[str, Any]' + if isinstance(additional, dict): + additional_schema = cast(dict[str, Any], additional) + value_expr = schema_to_type_expr(additional_schema, defs, referenced_types, tool_name, f'{path}Value') + return GenericTypeExpr(base='dict', args=['str', value_expr]) + return 'dict[str, Any]' + + # Type list (e.g., ['string', 'null']) + if isinstance(schema_type, list): + return _type_list_to_expr(schema_type, schema, defs, referenced_types, tool_name, path) + + return 'Any' + + +def _type_list_to_expr( + schema_type: list[str], + schema: dict[str, Any], + defs: dict[str, dict[str, Any]], + referenced_types: dict[str, TypeSignature], + tool_name: str, + path: str, +) -> TypeExpr: + """Handle type lists like ['string', 'null'].""" + # Check if this is object with properties + null + if 'object' in schema_type and 'properties' in schema: + base_expr = _type_to_expr('object', schema, defs, referenced_types, tool_name, path) + if 'null' in schema_type: + return UnionTypeExpr(members=[base_expr, 'None']) + return base_expr + + type_exprs: list[TypeExpr] = [_json_type_to_python(t) for t in schema_type] + type_exprs = [t for t in type_exprs if t] + if len(type_exprs) == 2 and 'None' in type_exprs: + non_none = [t for t in type_exprs if t != 'None'][0] + return UnionTypeExpr(members=[non_none, 'None']) + if type_exprs: + return UnionTypeExpr(members=type_exprs) if len(type_exprs) > 1 else type_exprs[0] + return 'Any' + + +def _handle_union_schema( + schemas: list[dict[str, Any]], + defs: dict[str, dict[str, Any]], + referenced_types: dict[str, TypeSignature], + tool_name: str, + path: str, +) -> TypeExpr: + """Handle anyOf/oneOf schemas, returning a TypeExpr.""" + type_exprs: list[TypeExpr] = [] + has_null = False + + for s in schemas: + if s.get('type') == 'null': + has_null = True + else: + type_exprs.append(schema_to_type_expr(s, defs, referenced_types, tool_name, path)) + + # Deduplicate while preserving order (compare rendered strings) + seen: set[str] = set() + unique_exprs: list[TypeExpr] = [] + for expr in type_exprs: + rendered = render_type_expr(expr) + if rendered not in seen: + seen.add(rendered) + unique_exprs.append(expr) + + if has_null: + unique_exprs.append('None') + + if len(unique_exprs) == 1: + return unique_exprs[0] + return UnionTypeExpr(members=unique_exprs) + + +def _build_and_register_type( + name: str, + schema: dict[str, Any], + defs: dict[str, dict[str, Any]], + referenced_types: dict[str, TypeSignature], + tool_name: str, + path: str, +) -> None: + """Build a TypeSignature and register it, pre-registering a placeholder to prevent infinite recursion. + + Self-referential schemas (e.g. recursive models) would otherwise cause infinite recursion + when `_build_type_signature` processes properties that `$ref` back to the same type. + """ + placeholder = TypeSignature(name=name) + referenced_types[name] = placeholder + built = _build_type_signature(name, schema, defs, referenced_types, tool_name, path) + placeholder.fields = built.fields + placeholder.docstring = built.docstring + + +def _build_type_signature( + name: str, + schema: dict[str, Any], + defs: dict[str, dict[str, Any]], + referenced_types: dict[str, TypeSignature], + tool_name: str, + path: str, +) -> TypeSignature: + """Build a TypeSignature from an object schema.""" + properties = schema.get('properties', {}) + required = set(schema.get('required', [])) + + fields: dict[str, TypeFieldSignature] = {} + + for prop_name, prop_schema in properties.items(): + prop_path = f'{path}.{prop_name}' if path else prop_name + type_expr = schema_to_type_expr(prop_schema, defs, referenced_types, tool_name, prop_path) + is_required = prop_name in required + desc = prop_schema.get('description', '') or None + + fields[prop_name] = TypeFieldSignature( + name=prop_name, + type=type_expr, + required=is_required, + description=desc, + ) + + docstring = schema.get('description') or None + return TypeSignature(name=name, docstring=docstring, fields=fields) + + +# ============================================================================= +# Deduplication +# ============================================================================= + + +def _replace_type_refs(sig: FunctionSignature, old_ref: TypeSignature, canonical: TypeSignature) -> None: + """Replace all references to old_ref with canonical in a signature's TypeExpr trees.""" + + def _replace_in_expr(expr: TypeExpr) -> TypeExpr: + if expr is old_ref: + return canonical + if isinstance(expr, GenericTypeExpr): + new_args = [_replace_in_expr(a) for a in expr.args] + if any(new is not orig for new, orig in zip(new_args, expr.args)): + expr.args = new_args + elif isinstance(expr, UnionTypeExpr): + new_members = [_replace_in_expr(m) for m in expr.members] + if any(new is not orig for new, orig in zip(new_members, expr.members)): + expr.members = new_members + return expr + + # Replace in params + for param in sig.params.values(): + param.type = _replace_in_expr(param.type) + + # Replace in return type + sig.return_type = _replace_in_expr(sig.return_type) + + # Replace in field types of referenced types + for type_sig in sig.referenced_types: + for f in type_sig.fields.values(): + f.type = _replace_in_expr(f.type) + + +def dedup_referenced_types(signatures: list[FunctionSignature]) -> None: + """Resolve TypedDict name conflicts across multiple tool signatures in place. + + Each signature keeps all its referenced types (so it remains self-contained), + but identical types (same name and structure) are unified to the same object + instance, and conflicting types (same name, different structure) are renamed + by prefixing the tool name to disambiguate. + + Use `collect_unique_referenced_types()` when rendering to emit each definition once. + """ + seen: dict[str, TypeSignature] = {} + + for sig in signatures: + deduped: list[TypeSignature] = [] + for type_sig in sig.referenced_types: + name = type_sig.name + if name not in seen: + # First occurrence — keep it + seen[name] = type_sig + deduped.append(type_sig) + elif seen[name].structurally_equal(type_sig): + # Same name, same definition — unify to canonical instance + canonical = seen[name] + _replace_type_refs(sig, type_sig, canonical) + deduped.append(canonical) + else: + # Same name, different definition — rename to avoid conflict + new_name = f'{sig.name}_{name}' + type_sig.name = new_name # Mutate in place → propagates everywhere via TypeExpr refs + seen[new_name] = type_sig + deduped.append(type_sig) + sig.referenced_types = deduped + + +def collect_unique_referenced_types(signatures: list[FunctionSignature]) -> list[TypeSignature]: + """Collect unique TypeSignature objects from signatures, deduplicating by identity.""" + seen_ids: set[int] = set() + result: list[TypeSignature] = [] + for sig in signatures: + for type_sig in sig.referenced_types: + if id(type_sig) not in seen_ids: + seen_ids.add(id(type_sig)) + result.append(type_sig) + return result diff --git a/pydantic_harness/capabilities/__init__.py b/pydantic_harness/capabilities/__init__.py new file mode 100644 index 0000000..5a2cf8c --- /dev/null +++ b/pydantic_harness/capabilities/__init__.py @@ -0,0 +1,6 @@ +"""Capabilities for pydantic-harness: CodeMode and ExecutionEnv.""" + +from .code_mode import CodeMode +from .execution_env import ExecutionEnv + +__all__ = ('CodeMode', 'ExecutionEnv') diff --git a/pydantic_harness/capabilities/code_mode.py b/pydantic_harness/capabilities/code_mode.py new file mode 100644 index 0000000..9c82c7a --- /dev/null +++ b/pydantic_harness/capabilities/code_mode.py @@ -0,0 +1,59 @@ +"""CodeMode capability — wraps CodeExecutionToolset as an AbstractCapability.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +from pydantic_ai.capabilities import AbstractCapability +from pydantic_ai.tools import AgentDepsT +from pydantic_ai.toolsets import AbstractToolset + +from pydantic_harness.environments._base import ExecutionEnvironment +from pydantic_harness.toolsets.code_execution import ( + CodeExecutionToolset, + DescriptionFunc, + EnvironmentName, + build_default_description, +) + + +@dataclass +class CodeMode(AbstractCapability[AgentDepsT]): + """Capability that provides code execution via CodeExecutionToolset. + + Wraps an ExecutionEnvironment (or environment name like 'monty') and an optional + toolset of tools to expose as callable Python functions within the code sandbox. + + Usage: + ```python {test="skip" lint="skip"} + from pydantic_ai import Agent + from pydantic_harness import CodeMode + + agent = Agent('anthropic:claude-sonnet-4-5', capabilities=[CodeMode()]) + ``` + """ + + environment: ExecutionEnvironment | EnvironmentName = 'monty' + """The code execution environment. Can be an instance or a string shorthand ('monty').""" + + toolset: AbstractToolset[AgentDepsT] | None = None + """Optional toolset to wrap. Its tools become callable Python functions in the sandbox.""" + + description: str | DescriptionFunc = field(default=build_default_description) + """Custom tool description. String or callback for full control.""" + + max_retries: int = 3 + """Maximum retries for code execution errors.""" + + @classmethod + def get_serialization_name(cls) -> str | None: + return 'CodeMode' + + def get_toolset(self) -> AbstractToolset[Any] | None: + return CodeExecutionToolset( + environment=self.environment, + toolset=self.toolset, + description=self.description, + max_retries=self.max_retries, + ) diff --git a/pydantic_harness/capabilities/execution_env.py b/pydantic_harness/capabilities/execution_env.py new file mode 100644 index 0000000..fe6b048 --- /dev/null +++ b/pydantic_harness/capabilities/execution_env.py @@ -0,0 +1,86 @@ +"""ExecutionEnv capability — wraps ExecutionEnvironmentToolset as an AbstractCapability.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +from pydantic_ai.capabilities import AbstractCapability +from pydantic_ai.tools import AgentDepsT +from pydantic_ai.toolsets import AbstractToolset + +from pydantic_harness.environments._base import ExecutionEnvironment +from pydantic_harness.toolsets.execution_environment import ( + Capability, + CodeLanguage, + EditStrategy, + ExecutionEnvironmentToolset, +) + + +@dataclass +class ExecutionEnv(AbstractCapability[AgentDepsT]): + """Capability that provides coding-agent-style tools backed by an ExecutionEnvironment. + + Exposes ls, shell, read_file, write_file, edit, glob, grep tools. + + Usage: + ```python {test="skip" lint="skip"} + from pydantic_ai import Agent + from pydantic_harness import ExecutionEnv + from pydantic_harness.environments.local import LocalEnvironment + + agent = Agent( + 'anthropic:claude-sonnet-4-5', + capabilities=[ExecutionEnv(environment=LocalEnvironment(root_dir='/tmp/workspace'))], + ) + ``` + """ + + environment: ExecutionEnvironment + """The execution environment backing the tools.""" + + include: frozenset[Capability] | None = None + """Capabilities to include. None = all (minus default excludes).""" + + exclude: frozenset[Capability] | None = None + """Capabilities to exclude. None = default excludes (run_code).""" + + edit_strategy: EditStrategy | None = None + """Edit tool strategy. None = auto-select.""" + + code_language: CodeLanguage | None = None + """Code execution language. None = auto-detect.""" + + require_shell_approval: bool = False + """Whether shell tool requires human approval.""" + + require_write_approval: bool = False + """Whether write/edit tools require human approval.""" + + image_support: bool = True + """Whether read_file returns images as BinaryContent.""" + + max_image_bytes: int = 50 * 1024 * 1024 + """Maximum image file size in bytes to return as BinaryContent.""" + + max_retries: int = 1 + """Maximum retries per tool call.""" + + @classmethod + def get_serialization_name(cls) -> str | None: + return 'ExecutionEnv' + + def get_toolset(self) -> AbstractToolset[Any] | None: + return ExecutionEnvironmentToolset( + environment=self.environment, + include=self.include, + exclude=self.exclude, + edit_strategy=self.edit_strategy, + code_language=self.code_language, + require_shell_approval=self.require_shell_approval, + require_write_approval=self.require_write_approval, + image_support=self.image_support, + max_image_bytes=self.max_image_bytes, + max_retries=self.max_retries, + ) diff --git a/pydantic_harness/environments/__init__.py b/pydantic_harness/environments/__init__.py new file mode 100644 index 0000000..d7d1237 --- /dev/null +++ b/pydantic_harness/environments/__init__.py @@ -0,0 +1,27 @@ +"""Execution environment abstractions for agents. + +This package provides: + +- `ExecutionEnvironment` — abstract base class for execution environments +- `ExecutionProcess` — interactive process handle with bidirectional I/O +- `ExecutionEnvironmentToolset` — toolset exposing coding-agent-style tools backed by an environment +- `ExecutionResult`, `FileInfo` — result types + +Implementations: + +- `environments.local.LocalEnvironment` — local subprocess environment (no isolation, for dev/testing) +- `environments.memory.MemoryEnvironment` — in-memory environment for testing +- `environments.monty.MontyEnvironment` — Monty sandboxed interpreter for code execution +""" + +from pydantic_harness.toolsets.execution_environment import ExecutionEnvironmentToolset + +from ._base import ExecutionEnvironment, ExecutionProcess, ExecutionResult, FileInfo + +__all__ = ( + 'ExecutionResult', + 'ExecutionEnvironment', + 'ExecutionEnvironmentToolset', + 'ExecutionProcess', + 'FileInfo', +) diff --git a/pydantic_harness/environments/_base.py b/pydantic_harness/environments/_base.py new file mode 100644 index 0000000..84b5a31 --- /dev/null +++ b/pydantic_harness/environments/_base.py @@ -0,0 +1,677 @@ +"""Base abstractions for execution environments. + +This module defines the core types, the `ExecutionEnvironment` ABC, and the +`ExecutionProcess` ABC for interactive execution with bidirectional streaming I/O. +""" + +from __future__ import annotations + +import fnmatch +import re +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Literal + +from typing_extensions import Self + +if TYPE_CHECKING: + from pydantic_harness._python_signature import FunctionSignature, TypeSignature + from pydantic_harness.toolsets.code_execution._abstract import FunctionCallback + + +# --- Capability type alias --- + +Capability = Literal[ + 'ls', + 'shell', + 'read_file', + 'write_file', + 'replace_str', + 'apply_patch', + 'glob', + 'grep', + 'run_python', + 'run_python_with_functions', + 'run_typescript', + 'run_typescript_with_functions', +] +"""Fine-grained capability identifier listing actual method names. + +Used in `capabilities` to declare which methods an environment implements. +Toolsets are responsible for mapping these to LLM-facing tool names. +""" + + +# --- Data types --- + + +@dataclass +class ExecutionResult: + """Result of a completed command execution.""" + + output: str + """The combined stdout/stderr output of the command.""" + + exit_code: int + """The exit code of the command.""" + + truncated: bool = False + """Whether the output was truncated due to length limits.""" + + +@dataclass +class FileInfo: + """Metadata about a file or directory.""" + + name: str + """The file or directory name.""" + + path: str + """The full path.""" + + is_dir: bool + """Whether this entry is a directory.""" + + size: int | None = None + """The file size in bytes, or None for directories.""" + + +class ExecutionProcess(ABC): + r"""Handle to a running process with bidirectional streaming I/O. + + Used for interactive execution where a script outputs data, + waits for input, processes it, and outputs more data. + + This is the lower-level building block for "code mode" where a + running script exchanges data via pipes. + """ + + @abstractmethod + async def send(self, data: bytes) -> None: + """Write data to the process's stdin. + + Args: + data: The bytes to write to stdin. + """ + + @abstractmethod + async def recv(self, timeout: float | None = None) -> bytes: + """Read available output from stdout. + + Blocks until data is available, the process exits, or the timeout expires. + + Args: + timeout: Maximum seconds to wait for data. None means wait indefinitely. + + Raises: + TimeoutError: If the timeout expires with no data available. + """ + + @abstractmethod + async def recv_stderr(self, timeout: float | None = None) -> bytes: + """Read available output from stderr. + + Args: + timeout: Maximum seconds to wait for data. None means wait indefinitely. + + Raises: + TimeoutError: If the timeout expires with no data available. + """ + + @property + @abstractmethod + def returncode(self) -> int | None: + """Return code if the process has exited, None if still running.""" + + @abstractmethod + async def wait(self, timeout: float | None = None) -> int: + """Wait for the process to exit. + + Args: + timeout: Maximum seconds to wait. None means wait indefinitely. + + Returns: + The process exit code. + + Raises: + TimeoutError: If the timeout expires before the process exits. + """ + + @abstractmethod + async def kill(self) -> None: + """Kill the process.""" + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, *args: Any) -> None: + if self.returncode is None: + await self.kill() + + +# --- Constants --- + +IMAGE_EXTENSIONS = frozenset( + { + '.png', + '.jpg', + '.jpeg', + '.gif', + '.webp', + '.bmp', + '.svg', + } +) + +IMAGE_MEDIA_TYPES: dict[str, str] = { + '.png': 'image/png', + '.jpg': 'image/jpeg', + '.jpeg': 'image/jpeg', + '.gif': 'image/gif', + '.webp': 'image/webp', + '.bmp': 'image/bmp', + '.svg': 'image/svg+xml', +} + +MAX_OUTPUT_CHARS = 100_000 + + +# --- ExecutionEnvironment --- + + +class ExecutionEnvironment(ABC): + """Abstract base class for execution environments. + + An execution environment provides a place where agents can execute + commands, read/write files, search the filesystem, and optionally + execute code. + + Implementations range from in-memory (for testing) to local subprocess, + Docker containers, and cloud-hosted VMs. + + The only abstract member is `capabilities`; all tool methods raise + `NotImplementedError` by default. Concrete subclasses override the + methods that match their declared capabilities. + """ + + # --- Capability introspection --- + + @property + @abstractmethod + def capabilities(self) -> frozenset[Capability]: + """Capabilities this environment supports (high-level). + + Used by toolsets to decide which tools to register. Only methods + corresponding to declared capabilities need to be implemented. + """ + ... + + def instructions(self, capability: Capability) -> str | None: + """Per-capability instructions for the LLM. + + Override to provide environment-specific hints that toolsets include + in the tool description shown to the model, e.g.:: + + def instructions(self, capability): + if capability == 'shell': + return 'Bash in Docker container, numpy/pandas installed' + if capability == 'grep': + return 'Uses POSIX basic regex, not Python re syntax' + return None + + Args: + capability: The capability name (e.g. `'shell'`, `'run_python'`). + + Returns: + Instruction text for the LLM, or None for no extra instructions. + """ + return None + + # --- Tool methods --- + # All raise NotImplementedError by default. Concrete subclasses override + # the methods that match their declared capabilities. + + async def ls(self, path: str = '.') -> list[FileInfo]: + """List directory contents. + + Args: + path: The directory path within the environment. + + Returns: + A list of `FileInfo` entries. + """ + raise NotImplementedError(f'{type(self).__name__} does not support ls.') + + async def shell( + self, + command: str, + *, + timeout: float | None = 120, + env: dict[str, str] | None = None, + ) -> ExecutionResult: + """Execute a shell command and return the result. + + Args: + command: The shell command to execute. + timeout: Maximum seconds to wait for completion. + Pass `None` to disable the timeout. + env: Additional environment variables for this command. + Merged with (and overrides) any baseline environment variables. + + Returns: + An `ExecutionResult` with the command output and exit code. + """ + raise NotImplementedError(f'{type(self).__name__} does not support shell.') + + async def read_file( + self, + path: str, + *, + offset: int = 0, + limit: int = 2000, + ) -> str | bytes: + """Read a file from the environment. + + For text files, returns a string with `cat -n` style line numbers. + For binary files (images), returns raw bytes. + + Args: + path: The file path within the environment. + offset: The line number to start reading from (0-indexed). + Ignored for binary files. + limit: Maximum number of lines to read. + Ignored for binary files. + + Returns: + Text content with line numbers (`str`), or raw bytes for binary files. + """ + raise NotImplementedError(f'{type(self).__name__} does not support read_file.') + + async def write_file(self, path: str, content: str | bytes) -> None: + """Create or overwrite a file in the environment. + + Args: + path: The file path within the environment. + content: The file content (text or binary). + """ + raise NotImplementedError(f'{type(self).__name__} does not support write_file.') + + async def replace_str( + self, + path: str, + old: str, + new: str, + *, + replace_all: bool = False, + ) -> int: + """Edit a file by exact string replacement. + + Args: + path: The file path within the environment. + old: The exact text to find. + new: The replacement text. + replace_all: If True, replace all occurrences. If False, the + old string must appear exactly once or an error is raised. + + Returns: + The number of replacements made. + + Raises: + FileNotFoundError: If the file does not exist. + ValueError: If `old` is not found, or appears multiple times + when `replace_all` is False. + """ + raise NotImplementedError(f'{type(self).__name__} does not support replace_str.') + + async def apply_patch(self, path: str, patch: str) -> str: + """Apply a unified diff patch to a file. + + Args: + path: The file path within the environment. + patch: The unified diff patch content. + + Returns: + The resulting file content after applying the patch. + """ + raise NotImplementedError(f'{type(self).__name__} does not support apply_patch.') + + async def glob(self, pattern: str, *, path: str = '.') -> list[str]: + """Find files matching a glob pattern. + + Args: + pattern: The glob pattern (e.g. `'**/*.py'`). + path: The directory to search in. + + Returns: + A list of matching file paths. + """ + raise NotImplementedError(f'{type(self).__name__} does not support glob.') + + async def grep( + self, + pattern: str, + *, + path: str | None = None, + glob_pattern: str | None = None, + output_mode: Literal['content', 'files_with_matches', 'count'] = 'content', + ) -> str: + """Search file contents with a regex pattern. + + Args: + pattern: The regex pattern to search for. + path: The file or directory to search in. + glob_pattern: Optional glob to filter which files are searched. + output_mode: Controls output format: + - `'content'` (default): matching lines as `file:line_number:text` + - `'files_with_matches'`: only file paths containing matches + - `'count'`: `file:count` pairs + + Returns: + Matching lines formatted as text. + """ + raise NotImplementedError(f'{type(self).__name__} does not support grep.') + + async def run_python(self, code: str) -> Any: + """Execute Python code. + + The default implementation writes the code to a temp file and runs + it via `shell()`. + + Args: + code: The Python code to execute. + + Returns: + The output of the code execution. + + Raises: + CodeRuntimeError: If the code exits with a non-zero status. + """ + await self.write_file('/tmp/_pydantic_ai_code.py', code) + result = await self.shell('python /tmp/_pydantic_ai_code.py') + if result.exit_code != 0: + from pydantic_harness.toolsets.code_execution._abstract import CodeRuntimeError + + raise CodeRuntimeError(result.output) + return result.output + + async def run_python_with_functions( + self, + code: str, + *, + function_callback: FunctionCallback, + functions: dict[str, FunctionSignature] | None = None, + referenced_types: list[TypeSignature] | None = None, + ) -> Any: + """Execute Python code with access to external functions. + + Environments that support this must include `'run_python_with_functions'` + in their `capabilities`. + + Args: + code: The Python code to execute. + function_callback: Callback invoked when the code calls an external function. + functions: Mapping of function name to signature (for type checking). + referenced_types: Type definitions referenced by the signatures. + + Returns: + The output of the code execution. + """ + raise NotImplementedError(f'{type(self).__name__} does not support run_python_with_functions.') + + async def run_typescript(self, code: str) -> Any: + """Execute TypeScript code. + + Not yet implemented. Reserved for future multi-language support. + """ + raise NotImplementedError(f'{type(self).__name__} does not support run_typescript.') + + async def run_typescript_with_functions( + self, + code: str, + *, + function_callback: FunctionCallback, + functions: dict[str, FunctionSignature] | None = None, + referenced_types: list[TypeSignature] | None = None, + ) -> Any: + """Execute TypeScript code with access to external functions. + + Not yet implemented. Reserved for future multi-language support. + """ + raise NotImplementedError(f'{type(self).__name__} does not support run_typescript_with_functions.') + + # --- State management --- + + def reset(self) -> None: + """Reset the environment to a clean state, discarding any accumulated REPL state. + + Called by toolsets when the model requests a session restart. Environments + that maintain persistent state (e.g. a REPL session) should override this + to destroy and recreate their state. Stateless environments can use the + default no-op. + """ + + def type_check( + self, + code: str, + *, + signatures: list[FunctionSignature] | None = None, + referenced_types: list[TypeSignature] | None = None, + ) -> None: + """Statically type-check code before execution. + + Only sound when the environment has no accumulated state (e.g. after a + `reset()`), since the type checker cannot see variables from prior calls. + + Override in subclasses that support type checking. The default is a no-op. + + Args: + code: The Python code to type-check. + signatures: Function signatures to include as type stubs. + referenced_types: Type definitions referenced by the signatures. + + Raises: + CodeTypingError: If type errors are found. + CodeSyntaxError: If the code has a syntax error. + """ + + # --- Internal helpers (not tools) --- + + async def create_process( + self, + command: str, + *, + env: dict[str, str] | None = None, + ) -> ExecutionProcess: + r"""Create an interactive process with streaming stdin/stdout. + + This is an internal helper for code execution drivers, not a tool. + + Args: + command: The shell command to run. + env: Additional environment variables for this process. + + Returns: + An `ExecutionProcess` handle for bidirectional I/O. + """ + raise NotImplementedError(f'{type(self).__name__} does not support interactive processes.') + + # --- Lifecycle --- + + async def __aenter__(self) -> Self: + """Start the environment (e.g., create a Docker container).""" + return self + + async def __aexit__(self, *args: Any) -> None: + """Stop the environment and clean up resources.""" + + +# --- Helper functions --- + + +def shell_escape(s: str) -> str: + """Escape a string for safe use in shell commands.""" + return "'" + s.replace("'", "'\\''") + "'" + + +def format_lines(text: str, offset: int, limit: int) -> str: + """Format text with line numbers and continuation hints. + + Shared helper used by `LocalEnvironment` and `MemoryEnvironment` + to produce consistent `cat -n` style output. + """ + lines = text.splitlines(keepends=True) + total_lines = len(lines) + + if offset >= total_lines and total_lines > 0: + raise ValueError(f'Offset {offset} exceeds file length ({total_lines} lines).') + + selected = lines[offset : offset + limit] + + numbered = [f'{i:>6}\t{line}' for i, line in enumerate(selected, start=offset + 1)] + result = ''.join(numbered) + if not result.endswith('\n'): + result += '\n' + + remaining = total_lines - (offset + len(selected)) + if remaining > 0: + next_offset = offset + len(selected) + result += f'... ({remaining} more lines. Use offset={next_offset} to continue reading.)\n' + + return result + + +def collect_grep_matches( + rel_path: str, + text: str, + compiled: re.Pattern[str], + output_mode: Literal['content', 'files_with_matches', 'count'], + results: list[str], +) -> None: + """Collect grep matches from a single file into `results`. + + Shared helper used by `LocalEnvironment` and `MemoryEnvironment`. + """ + if output_mode == 'files_with_matches': + if any(compiled.search(line) for line in text.splitlines()): + results.append(rel_path) + elif output_mode == 'count': + match_count = sum(1 for line in text.splitlines() if compiled.search(line)) + if match_count > 0: + results.append(f'{rel_path}:{match_count}') + else: + for line_num, line in enumerate(text.splitlines(), start=1): + if compiled.search(line): + results.append(f'{rel_path}:{line_num}:{line}') + + +def glob_match(path: str, pattern: str) -> bool: + """Match a path against a glob pattern with `**` support. + + `fnmatch` does not support `**` for recursive matching. + This helper converts glob patterns to regex so that `**` + matches zero or more path segments (including `/`). + """ + if '**' not in pattern: + return fnmatch.fnmatch(path, pattern) + + regex = '' + i = 0 + while i < len(pattern): + if pattern[i : i + 3] == '**/': + regex += '(.*/)?' + i += 3 + elif pattern[i : i + 2] == '**': + regex += '.*' + i += 2 + elif pattern[i] == '*': + regex += '[^/]*' + i += 1 + elif pattern[i] == '?': + regex += '[^/]' + i += 1 + else: + regex += re.escape(pattern[i]) + i += 1 + return bool(re.fullmatch(regex, path)) + + +# --- Shell command builders for Docker environments --- + + +def build_read_file_cmd(path: str, *, offset: int = 0, limit: int = 2000) -> str: + """Build a shell command that reads a file with line numbers. + + Uses `awk` for reliable line numbering that handles tabs correctly. + Includes a continuation hint when more lines remain, consistent + with the `format_lines` helper used by Local/Memory environments. + """ + escaped = shell_escape(path) + start = offset + 1 + end = offset + limit + return ( + f'awk \'NR>={start} && NR<={end} {{printf "%6d\\t%s\\n", NR, $0}}' + f' END {{if(NR>{end}) printf "... (%d more lines. Use offset={end} to continue reading.)\\n", NR-{end}}}\'' + f' {escaped}' + ) + + +def build_grep_cmd( + pattern: str, + *, + path: str | None = None, + glob_pattern: str | None = None, + output_mode: Literal['content', 'files_with_matches', 'count'] = 'content', +) -> str: + """Build a shell `grep` command from structured arguments.""" + parts = ['grep', '-rI'] # -I skips binary files + if output_mode == 'files_with_matches': + parts.append('-l') + elif output_mode == 'count': + parts.append('-c') + else: + parts.append('-n') + if glob_pattern: + parts.extend(['--include', shell_escape(glob_pattern)]) + parts.append(shell_escape(pattern)) + parts.append(shell_escape(path or '.')) + return ' '.join(parts) + + +def filter_grep_count_output(text: str) -> str: + """Filter `grep -c` output to remove files with 0 matches.""" + return '\n'.join(line for line in text.splitlines() if not line.endswith(':0')) + + +def build_glob_cmd(pattern: str, *, path: str = '.') -> str: + """Build a shell `find` command to match files by pattern.""" + return f'find {shell_escape(path)} -path {shell_escape(pattern)} -o -name {shell_escape(pattern)} 2>/dev/null | head -100' + + +def parse_glob_output(text: str) -> list[str]: + """Parse output of a find/glob command into a list of paths.""" + text = text.strip() + if not text: + return [] + return [line for line in text.splitlines() if line] + + +def apply_edit(text: str, old_string: str, new_string: str, path: str, *, replace_all: bool) -> tuple[str, int]: + """Apply a string replacement edit, returning the new text and the number of replacements. + + Raises: + ValueError: If old_string is not found, or appears multiple times + when replace_all is False. + """ + count = text.count(old_string) + + if count == 0: + raise ValueError(f'old_string not found in {path}.') + if not replace_all and count > 1: + raise ValueError(f'old_string found {count} times in {path}. Use replace_all=True or provide more context.') + + if replace_all: + new_text = text.replace(old_string, new_string) + else: + new_text = text.replace(old_string, new_string, 1) + + return new_text, count if replace_all else 1 diff --git a/pydantic_harness/environments/_driver.py b/pydantic_harness/environments/_driver.py new file mode 100644 index 0000000..0bf0489 --- /dev/null +++ b/pydantic_harness/environments/_driver.py @@ -0,0 +1,333 @@ +"""Host-side ABC for driver-based execution environments. + +Provides `DriverBasedEnvironment`, an intermediate abstract base class that +extends `ExecutionEnvironment` with code execution via the NDJSON driver protocol. +Concrete subclasses (Docker, Local) implement `_start_driver` and `_copy_driver`. +""" + +from __future__ import annotations + +import asyncio +import json +from abc import ABC, abstractmethod +from dataclasses import dataclass +from enum import Enum, auto +from typing import TYPE_CHECKING, Any + +from pydantic_harness.toolsets.code_execution._abstract import ( + CodeExecutionTimeout, + CodeRuntimeError, + CodeSyntaxError, + FunctionCall, +) + +from ._base import ExecutionEnvironment + +if TYPE_CHECKING: + from pydantic_harness._python_signature import FunctionSignature, TypeSignature + from pydantic_harness.toolsets.code_execution._abstract import FunctionCallback + + from ._base import ExecutionProcess + + +class DriverTransport(ABC): + """Interface for communicating with a driver process. + + Concrete implementations wrap platform-specific transport types + (`asyncio.subprocess.Process`, SDK handles, WebSocket connections, etc.). + """ + + @abstractmethod + async def read_line(self) -> bytes: + """Read a single newline-terminated line from the driver's stdout.""" + ... + + @abstractmethod + async def write_line(self, data: bytes) -> None: + """Write a line to the driver's stdin (must include trailing newline).""" + ... + + @abstractmethod + async def read_stderr(self) -> bytes: + """Read all available stderr output from the driver.""" + ... + + @abstractmethod + async def kill(self) -> None: + """Terminate the driver process.""" + ... + + +class ExecutionProcessTransport(DriverTransport): + """Adapts an `ExecutionProcess` to the `DriverTransport` interface. + + Provides line-buffered reads on top of the raw `recv()` interface, + which is what the NDJSON protocol requires. + """ + + def __init__(self, process: ExecutionProcess) -> None: + self._proc = process + self._buffer = b'' + + async def read_line(self) -> bytes: + while b'\n' not in self._buffer: + chunk = await self._proc.recv() + if not chunk: + remaining = self._buffer + self._buffer = b'' + return remaining + self._buffer += chunk + line, self._buffer = self._buffer.split(b'\n', 1) + return line + b'\n' + + async def write_line(self, data: bytes) -> None: + await self._proc.send(data) + + async def read_stderr(self) -> bytes: + try: + return await self._proc.recv_stderr(timeout=1.0) + except Exception: + return b'' + + async def kill(self) -> None: + await self._proc.kill() + + +class _ToolError(Exception): + """Wrapper to distinguish tool execution errors from transport/protocol errors.""" + + +class _StdoutSignal(Enum): + """Typed signals from _handle_stdout indicating what happened.""" + + CONTINUE = auto() + + +@dataclass(frozen=True) +class _FinalResult: + """Wraps the final result value from a completed driver execution.""" + + value: Any + + +class DriverBasedEnvironment(ExecutionEnvironment, ABC): + """Environment with code execution via the NDJSON driver protocol. + + Extends `ExecutionEnvironment` with `run_python` that launches a + driver script inside the environment and communicates via NDJSON over + stdin/stdout. The driver handles code compilation, execution, and + proxying of external function calls. + + Subclasses must implement `_copy_driver` (install the driver script into + the environment). The default `_start_driver` uses `create_process` + with an `ExecutionProcessTransport` adapter; override for custom transport. + """ + + execution_timeout: float | None = None + """Optional timeout in seconds for code execution. None means no timeout.""" + + driver_python_path: str = 'python' + """Path to the Python interpreter inside the environment.""" + + driver_script_path: str = '/tmp/pydantic_ai_driver.py' + """Path where the driver script is installed inside the environment.""" + + _driver_copied: bool = False + + # --- Driver protocol --- + + async def _start_driver(self, init_msg: dict[str, Any]) -> DriverTransport: + """Launch the driver process and send the init message. + + The default implementation uses `create_process` with an + `ExecutionProcessTransport` adapter. Override for custom transport + (e.g. asyncio subprocess with the Docker CLI). + + Args: + init_msg: The init message dict to send to the driver. + + Returns: + A DriverTransport for communicating with the driver. + """ + proc = await self.create_process(f'{self.driver_python_path} -u {self.driver_script_path}') + await proc.__aenter__() + transport = ExecutionProcessTransport(proc) + init_line = json.dumps(init_msg).encode() + b'\n' + await transport.write_line(init_line) + return transport + + @abstractmethod + async def _copy_driver(self) -> None: + """Install the driver script into the environment. + + Called once before the first `run_python` invocation. Implementations + should copy the driver script from the host to the environment + (e.g. via `docker exec tee`, file API, or local file reference). + """ + ... + + async def run_python_with_functions( + self, + code: str, + *, + function_callback: FunctionCallback, + functions: dict[str, FunctionSignature] | None = None, + referenced_types: list[TypeSignature] | None = None, + ) -> Any: + """Execute Python code with external functions via the NDJSON driver protocol.""" + if not self._driver_copied: # pragma: no branch + await self._copy_driver() + self._driver_copied = True + + init_msg: dict[str, Any] = { + 'type': 'init', + 'code': code, + 'functions': list(functions) if functions else [], + } + process = await self._start_driver(init_msg) + try: + return await self._run_with_timeout(process, function_callback) + except (CodeSyntaxError, CodeRuntimeError): + raise + except _ToolError as e: + if e.__cause__ is None: # pragma: no cover + raise + raise e.__cause__ + except Exception as e: + raise CodeRuntimeError(f'Driver communication error: {e}') from e + + # --- Protocol implementation --- + + async def _run_with_timeout(self, process: DriverTransport, function_callback: FunctionCallback) -> Any: + """Run the execution loop, applying `execution_timeout` if configured.""" + coro = self._execution_loop(process, function_callback) + if self.execution_timeout is not None: + try: + return await asyncio.wait_for(coro, timeout=self.execution_timeout) + except asyncio.TimeoutError: + await process.kill() + raise CodeExecutionTimeout(f'Code execution timed out after {self.execution_timeout} seconds') + return await coro + + async def _execution_loop(self, process: DriverTransport, function_callback: FunctionCallback) -> Any: + """Run the dual-wait event loop: read driver stdout + dispatch tool tasks.""" + tool_tasks: dict[int, asyncio.Task[Any]] = {} + task_id_to_cid: dict[int, int] = {} + + stdout_task: asyncio.Task[bytes] = asyncio.create_task(process.read_line()) + + try: + while True: + waitables: list[asyncio.Task[Any]] = [stdout_task, *tool_tasks.values()] + done, _ = await asyncio.wait(waitables, return_when=asyncio.FIRST_COMPLETED) + + for task in done: + if task is stdout_task: + result = await self._handle_stdout(task, process, function_callback, tool_tasks, task_id_to_cid) + if isinstance(result, _FinalResult): + return result.value + stdout_task = asyncio.create_task(process.read_line()) + else: + try: + await self._handle_tool_done( + task, + process, + tool_tasks, + task_id_to_cid, + ) + except Exception as e: + raise _ToolError() from e + finally: + await _cancel_all(tool_tasks, stdout_task, process) + + @staticmethod + async def _handle_stdout( + task: asyncio.Task[bytes], + process: DriverTransport, + function_callback: FunctionCallback, + tool_tasks: dict[int, asyncio.Task[Any]], + task_id_to_cid: dict[int, int], + ) -> _StdoutSignal | _FinalResult: + """Handle a completed stdout read task. Returns a signal or the final result.""" + raw = task.result() + if not raw: + stderr = b'' + try: + stderr = await asyncio.wait_for(process.read_stderr(), timeout=1.0) + except Exception: + pass + err_msg = stderr.decode(errors='replace').strip() if stderr else 'Driver process exited unexpectedly' + raise CodeRuntimeError(err_msg) + + try: + msg = json.loads(raw) + except json.JSONDecodeError: + raise CodeRuntimeError(f'Malformed protocol message from driver: {raw[:200]!r}') + + msg_type = msg.get('type') + + if msg_type == 'call': + cid = msg['id'] + fc = FunctionCall( + call_id=str(cid), + function_name=msg['function'], + args=tuple(msg.get('args', ())), + kwargs=msg.get('kwargs', {}), + ) + t = asyncio.ensure_future(function_callback(fc)) + tool_tasks[cid] = t + task_id_to_cid[id(t)] = cid + return _StdoutSignal.CONTINUE + elif msg_type == 'calls_ready': + return _StdoutSignal.CONTINUE + elif msg_type == 'complete': + await process.kill() + return _FinalResult(value=msg.get('result')) + elif msg_type == 'error': + await process.kill() + error_type = msg.get('error_type', 'runtime') + error_msg = msg.get('error', 'Unknown driver error') + if error_type == 'syntax': + raise CodeSyntaxError(error_msg) + raise CodeRuntimeError(error_msg) + + return _StdoutSignal.CONTINUE + + @staticmethod + async def _handle_tool_done( + task: asyncio.Task[Any], + process: DriverTransport, + tool_tasks: dict[int, asyncio.Task[Any]], + task_id_to_cid: dict[int, int], + ) -> None: + """Handle a completed tool task: send result back to the driver.""" + cid = task_id_to_cid.pop(id(task)) + del tool_tasks[cid] + + result = task.result() + result_msg = json.dumps({'type': 'result', 'id': cid, 'result': result}) + '\n' + await process.write_line(result_msg.encode()) + + +async def _cancel_task(task: asyncio.Task[Any]) -> None: + """Cancel a task and suppress CancelledError.""" + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + + +async def _cancel_all( + tool_tasks: dict[int, asyncio.Task[Any]], + stdout_task: asyncio.Task[Any], + process: DriverTransport, +) -> None: + """Cancel all pending tasks and kill the driver process.""" + all_tasks = [*tool_tasks.values(), stdout_task] + for t in all_tasks: + await _cancel_task(t) + try: + await process.kill() + except Exception: + pass diff --git a/pydantic_harness/environments/local.py b/pydantic_harness/environments/local.py new file mode 100644 index 0000000..f68c61f --- /dev/null +++ b/pydantic_harness/environments/local.py @@ -0,0 +1,350 @@ +"""Local subprocess-based execution environment for development and testing. + +Runs commands directly on the host machine within a specified root directory. +**No isolation** — use `DockerEnvironment` for untrusted code. +""" + +from __future__ import annotations + +import re +import subprocess +from pathlib import Path +from typing import Any, Literal + +import anyio +import anyio.abc +from typing_extensions import Self + +from ._base import ( + IMAGE_EXTENSIONS, + MAX_OUTPUT_CHARS, + Capability, + ExecutionProcess, + ExecutionResult, + FileInfo, + apply_edit, + collect_grep_matches, + format_lines, +) +from ._driver import DriverBasedEnvironment + + +class LocalEnvironmentProcess(ExecutionProcess): + """Interactive process backed by `anyio.abc.Process`.""" + + def __init__(self, proc: anyio.abc.Process) -> None: + self._proc = proc + + async def send(self, data: bytes) -> None: + stdin = self._proc.stdin + if stdin is None: + raise RuntimeError('Process stdin is not available.') + await stdin.send(data) + + async def recv(self, timeout: float | None = None) -> bytes: + stdout = self._proc.stdout + if stdout is None: + raise RuntimeError('Process stdout is not available.') + try: + if timeout is not None: + with anyio.fail_after(timeout): + return await stdout.receive(8192) + return await stdout.receive(8192) + except anyio.EndOfStream: + return b'' + + async def recv_stderr(self, timeout: float | None = None) -> bytes: + stderr = self._proc.stderr + if stderr is None: + raise RuntimeError('Process stderr is not available.') + try: + if timeout is not None: + with anyio.fail_after(timeout): + return await stderr.receive(8192) + return await stderr.receive(8192) + except anyio.EndOfStream: + return b'' + + @property + def returncode(self) -> int | None: + return self._proc.returncode + + async def wait(self, timeout: float | None = None) -> int: + if timeout is not None: + with anyio.fail_after(timeout): + return await self._proc.wait() + return await self._proc.wait() + + async def kill(self) -> None: + try: + self._proc.kill() + except ProcessLookupError: + pass + await self._proc.aclose() + _close_subprocess_transport(self._proc) + + +def _close_subprocess_transport(proc: anyio.abc.Process) -> None: + """Close the underlying asyncio subprocess transport to prevent ResourceWarning on Python 3.10. + + On Python 3.10, asyncio subprocess transports are not closed by + `Process.wait()` or `Process.aclose()` and their `__del__` + emits `ResourceWarning: unclosed transport`. Python 3.11+ fixed + this, but we still support 3.10. + """ + inner = getattr(proc, '_process', None) # anyio wraps asyncio.subprocess.Process + transport = getattr(inner, '_transport', None) + if transport is not None: # pragma: no branch + transport.close() + + +class LocalEnvironment(DriverBasedEnvironment): + """Local subprocess-based execution environment for development and testing. + + Runs commands directly on the host machine within a specified root + directory. Provides no isolation — use `DockerEnvironment` for untrusted code. + + Usage: + ```python {test="skip" lint="skip"} + async with LocalEnvironment(root_dir='/tmp/workspace') as env: + result = await env.shell('python script.py') + print(result.output) + ``` + """ + + driver_script_path: str = str(Path(__file__).parents[1] / 'toolsets' / 'code_execution' / '_driver.py') + + def __init__( + self, + root_dir: str | Path = '.', + *, + env_vars: dict[str, str] | None = None, + inherit_env: bool = True, + ) -> None: + """Create a local execution environment. + + Args: + root_dir: The working directory for all operations. + Defaults to the current directory. + env_vars: Baseline environment variables for all commands. + inherit_env: Whether to inherit the host's environment variables. + When True (default), `env_vars` and per-call `env` are merged + on top of `os.environ`. When False, only `env_vars` and per-call + `env` are used (useful for reproducibility and testing). + """ + self._root_dir = Path(root_dir).resolve() + self._env_vars = env_vars or {} + self._inherit_env = inherit_env + + @property + def capabilities(self) -> frozenset[Capability]: + return frozenset({'ls', 'shell', 'read_file', 'write_file', 'replace_str', 'glob', 'grep', 'run_python'}) + + async def __aenter__(self) -> Self: + self._root_dir.mkdir(parents=True, exist_ok=True) + return self + + async def __aexit__(self, *_args: Any) -> None: + pass + + def _resolve_path(self, path: str) -> Path: + """Resolve a path relative to root_dir, preventing traversal.""" + resolved = (self._root_dir / path).resolve() + if not resolved.is_relative_to(self._root_dir): + raise PermissionError(f'Path {path!r} resolves outside the environment root.') + return resolved + + def _build_env(self, env: dict[str, str] | None) -> dict[str, str] | None: + """Merge baseline env vars with per-call overrides.""" + if not self._env_vars and not env and self._inherit_env: + return None # subprocess inherits naturally + import os + + merged = {**os.environ} if self._inherit_env else {} + merged.update(self._env_vars) + if env: + merged.update(env) + return merged + + async def create_process( + self, + command: str, + *, + env: dict[str, str] | None = None, + ) -> ExecutionProcess: + proc = await anyio.open_process( + command, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + cwd=self._root_dir, + env=self._build_env(env), + ) + return LocalEnvironmentProcess(proc) + + async def shell( + self, + command: str, + *, + timeout: float | None = 120, + env: dict[str, str] | None = None, + ) -> ExecutionResult: + """Execute a command using subprocess for simplicity and reliability.""" + proc = await anyio.open_process( + command, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + cwd=self._root_dir, + env=self._build_env(env), + ) + try: + assert proc.stdout is not None + chunks: list[bytes] = [] + if timeout is not None: + with anyio.fail_after(timeout): + async for chunk in proc.stdout: + chunks.append(chunk) + await proc.wait() + else: + async for chunk in proc.stdout: + chunks.append(chunk) + await proc.wait() + except TimeoutError: + proc.kill() + with anyio.CancelScope(shield=True): + await proc.wait() + _close_subprocess_transport(proc) + return ExecutionResult(output='[Command timed out]', exit_code=-1) + + _close_subprocess_transport(proc) + stdout = b''.join(chunks) + output = stdout.decode('utf-8', errors='replace') + truncated = len(output) > MAX_OUTPUT_CHARS + if truncated: + output = output[:MAX_OUTPUT_CHARS] + return ExecutionResult( + output=output, + exit_code=proc.returncode if proc.returncode is not None else 0, + truncated=truncated, + ) + + async def read_file(self, path: str, *, offset: int = 0, limit: int = 2000) -> str | bytes: + resolved = self._resolve_path(path) + if not resolved.is_file(): + if resolved.is_dir(): + raise FileNotFoundError(f"'{path}' is a directory, not a file.") + raise FileNotFoundError(f'File not found: {path}') + + if resolved.suffix.lower() in IMAGE_EXTENSIONS: + return resolved.read_bytes() + + raw = resolved.read_bytes() + try: + text = raw.decode('utf-8') + except UnicodeDecodeError: + return raw + return format_lines(text, offset, limit) + + async def write_file(self, path: str, content: str | bytes) -> None: + resolved = self._resolve_path(path) + resolved.parent.mkdir(parents=True, exist_ok=True) + if isinstance(content, bytes): + resolved.write_bytes(content) + else: + resolved.write_text(content, encoding='utf-8') + + async def replace_str( + self, + path: str, + old: str, + new: str, + *, + replace_all: bool = False, + ) -> int: + resolved = self._resolve_path(path) + if not resolved.is_file(): + raise FileNotFoundError(f'File not found: {path}') + + text = resolved.read_text(encoding='utf-8') + new_text, count = apply_edit(text, old, new, path, replace_all=replace_all) + resolved.write_text(new_text, encoding='utf-8') + return count + + async def ls(self, path: str = '.') -> list[FileInfo]: + resolved = self._resolve_path(path) + if not resolved.is_dir(): + raise NotADirectoryError(f'Not a directory: {path}') + + entries: list[FileInfo] = [] + for entry in sorted(resolved.iterdir()): + try: + stat = entry.stat() + entries.append( + FileInfo( + name=entry.name, + path=str(entry.relative_to(self._root_dir)), + is_dir=entry.is_dir(), + size=stat.st_size if not entry.is_dir() else None, + ) + ) + except OSError: # pragma: no cover + continue + return entries + + async def glob(self, pattern: str, *, path: str = '.') -> list[str]: + resolved = self._resolve_path(path) + matches: list[str] = [] + for match in sorted(resolved.glob(pattern)): + try: + rel = str(match.relative_to(self._root_dir)) + matches.append(rel) + except ValueError: # pragma: no cover + continue + return matches + + async def grep( + self, + pattern: str, + *, + path: str | None = None, + glob_pattern: str | None = None, + output_mode: Literal['content', 'files_with_matches', 'count'] = 'content', + ) -> str: + search_dir = self._resolve_path(path or '.') + compiled = re.compile(pattern) + + if search_dir.is_file(): + files = [search_dir] + elif glob_pattern: + files = sorted(search_dir.glob(glob_pattern)) + else: + files = sorted(search_dir.rglob('*')) + + results: list[str] = [] + for file_path in files: + if not file_path.is_file(): + continue + # Skip hidden files/directories (e.g. .git/, .venv/) + if any(part.startswith('.') for part in file_path.relative_to(self._root_dir).parts): + continue + try: + raw = file_path.read_bytes() + except OSError: + continue + + # Skip binary files (null byte in first 8KB) + if b'\x00' in raw[:8192]: + continue + + text = raw.decode('utf-8', errors='replace') + rel_path = str(file_path.relative_to(self._root_dir)) + collect_grep_matches(rel_path, text, compiled, output_mode, results) + + if len(results) > 1000: + results.append('[... truncated at 1000 matches]') + break + + return '\n'.join(results) + + async def _copy_driver(self) -> None: + """No-op: the driver script is accessed directly from the package path.""" diff --git a/pydantic_harness/environments/memory.py b/pydantic_harness/environments/memory.py new file mode 100644 index 0000000..5e406b7 --- /dev/null +++ b/pydantic_harness/environments/memory.py @@ -0,0 +1,267 @@ +"""In-memory execution environment for testing. + +All file operations use an in-memory dictionary. Shell commands are handled +by an optional callback — if not provided, `shell()` raises `RuntimeError`. +""" + +from __future__ import annotations + +import fnmatch +import posixpath +import re +from collections.abc import Callable +from typing import TYPE_CHECKING, Literal + +from ._base import ( + IMAGE_EXTENSIONS, + ExecutionEnvironment, + ExecutionResult, + FileInfo, + apply_edit, + collect_grep_matches, + format_lines, + glob_match, +) + +if TYPE_CHECKING: + from ._base import Capability + + +class MemoryEnvironment(ExecutionEnvironment): + """In-memory execution environment for testing. + + File operations use an in-memory dictionary, making tests fast and + isolated with no filesystem access. Shell commands can optionally be + handled by a user-provided callback. + + This is the testing counterpart to `LocalEnvironment`, analogous to + how `TestModel` and `FunctionModel` relate to real model classes. + + Usage: + ```python {test="skip" lint="skip"} + from pydantic_ai.environments.memory import MemoryEnvironment + + env = MemoryEnvironment(files={'main.py': 'print("hello")'}) + async with env: + content = await env.read_file('main.py') + assert 'hello' in content + ``` + """ + + def __init__( + self, + files: dict[str, str | bytes] | None = None, + *, + command_handler: Callable[[str], ExecutionResult] | None = None, + ) -> None: + """Create an in-memory execution environment. + + Args: + files: Initial files to populate the environment with. + Keys are file paths, values are file contents (str or bytes). + command_handler: Optional callback for `shell()` calls. + Receives the command string and returns an `ExecutionResult`. + If not provided, `shell()` raises `RuntimeError`. + """ + self._files: dict[str, str | bytes] = {} + if files: + for path, content in files.items(): + self._files[self._normalize(path)] = content + self._command_handler = command_handler + + @property + def capabilities(self) -> frozenset[Capability]: + return frozenset({'ls', 'read_file', 'write_file', 'replace_str', 'glob', 'grep'}) + + @staticmethod + def _normalize(path: str) -> str: + """Normalize a path for consistent storage.""" + normalized = posixpath.normpath(path) + # Strip leading './' or '/' + if normalized.startswith('./'): # pragma: no cover + normalized = normalized[2:] + elif normalized.startswith('/'): + normalized = normalized[1:] + return normalized + + async def shell( + self, + command: str, + *, + timeout: float | None = 120, + env: dict[str, str] | None = None, + ) -> ExecutionResult: + """Execute a command using the configured handler. + + Args: + command: The shell command to execute. + timeout: Ignored for MemoryEnvironment. + env: Ignored for MemoryEnvironment. + + Returns: + The result from the command handler. + + Raises: + RuntimeError: If no command_handler was provided. + """ + if self._command_handler is None: + raise RuntimeError( + 'MemoryEnvironment has no command_handler configured. ' + 'Pass command_handler= to the constructor to handle shell() calls.' + ) + return self._command_handler(command) + + async def read_file(self, path: str, *, offset: int = 0, limit: int = 2000) -> str | bytes: + normalized = self._normalize(path) + + # Check if path is a "directory" (any file starts with path/) + if any(k.startswith(normalized + '/') for k in self._files): + if normalized not in self._files: + raise FileNotFoundError(f"'{path}' is a directory, not a file.") + + if normalized not in self._files: + raise FileNotFoundError(f'File not found: {path}') + + content = self._files[normalized] + + # Return raw bytes for image files + ext = posixpath.splitext(normalized)[1].lower() + if ext in IMAGE_EXTENSIONS: + if isinstance(content, bytes): + return content + return content.encode('utf-8') + + # Text mode + if isinstance(content, bytes): + try: + text = content.decode('utf-8') + except UnicodeDecodeError: + return content + else: + text = content + + return format_lines(text, offset, limit) + + async def write_file(self, path: str, content: str | bytes) -> None: + self._files[self._normalize(path)] = content + + async def replace_str( + self, + path: str, + old: str, + new: str, + *, + replace_all: bool = False, + ) -> int: + normalized = self._normalize(path) + if normalized not in self._files: + raise FileNotFoundError(f'File not found: {path}') + + content = self._files[normalized] + text = content.decode('utf-8') if isinstance(content, bytes) else content + new_text, count = apply_edit(text, old, new, path, replace_all=replace_all) + self._files[normalized] = new_text + return count + + async def ls(self, path: str = '.') -> list[FileInfo]: + normalized = self._normalize(path) + + # Collect direct children + entries: dict[str, FileInfo] = {} + for file_path in sorted(self._files): + if normalized == '.': + rel = file_path + elif file_path.startswith(normalized + '/'): + rel = file_path[len(normalized) + 1 :] + else: + continue + + # Get the first component (direct child) + parts = rel.split('/', 1) + name = parts[0] + if name in entries: + continue + + is_dir = len(parts) > 1 + if is_dir: + entries[name] = FileInfo( + name=name, + path=f'{normalized}/{name}' if normalized != '.' else name, + is_dir=True, + ) + else: + content = self._files[file_path] + size = len(content) if isinstance(content, bytes) else len(content.encode('utf-8')) + entries[name] = FileInfo( + name=name, + path=f'{normalized}/{name}' if normalized != '.' else name, + is_dir=False, + size=size, + ) + + if not entries and normalized != '.': + raise NotADirectoryError(f'Not a directory: {path}') + + return list(entries.values()) + + async def glob(self, pattern: str, *, path: str = '.') -> list[str]: + normalized = self._normalize(path) + matches: list[str] = [] + for file_path in sorted(self._files): + if normalized != '.': + if not file_path.startswith(normalized + '/'): + continue + rel = file_path[len(normalized) + 1 :] + else: + rel = file_path + + if glob_match(rel, pattern): + matches.append(file_path) + + return matches + + async def grep( + self, + pattern: str, + *, + path: str | None = None, + glob_pattern: str | None = None, + output_mode: Literal['content', 'files_with_matches', 'count'] = 'content', + ) -> str: + normalized = self._normalize(path or '.') + compiled = re.compile(pattern) + + results: list[str] = [] + for file_path in sorted(self._files): + # Path filtering + if normalized != '.': + if normalized == file_path: + pass # exact file match + elif not file_path.startswith(normalized + '/'): + continue + + # Glob filtering + if glob_pattern and not fnmatch.fnmatch(posixpath.basename(file_path), glob_pattern): + continue + + # Skip hidden files + if any(part.startswith('.') for part in file_path.split('/')): + continue + + content = self._files[file_path] + + # Skip binary files + if isinstance(content, bytes): + if b'\x00' in content[:8192]: + continue + text = content.decode('utf-8', errors='replace') + else: + text = content + + collect_grep_matches(file_path, text, compiled, output_mode, results) + + if len(results) > 1000: + results.append('[... truncated at 1000 matches]') + break + + return '\n'.join(results) diff --git a/pydantic_harness/environments/monty.py b/pydantic_harness/environments/monty.py new file mode 100644 index 0000000..bc6448b --- /dev/null +++ b/pydantic_harness/environments/monty.py @@ -0,0 +1,311 @@ +"""Monty sandboxed interpreter environment for code execution. + +Requires the `pydantic-monty` package: `pip install pydantic-harness[monty]` +""" + +from __future__ import annotations + +import asyncio +import textwrap +from typing import TYPE_CHECKING, Any + +from typing_extensions import Self, assert_never + +from pydantic_harness.toolsets.code_execution._abstract import ( + CodeExecutionTimeout, + CodeRuntimeError, + CodeSyntaxError, + CodeTypingError, + FunctionCall, +) + +from ._base import ExecutionEnvironment + +if TYPE_CHECKING: + from pydantic_harness._python_signature import FunctionSignature, TypeSignature + from pydantic_harness.toolsets.code_execution._abstract import FunctionCallback + + from ._base import Capability + +try: + from pydantic_monty import ( + ExternalReturnValue, + FunctionSnapshot, + FutureSnapshot, + Monty, + MontyComplete, + MontyRepl, + MontyRuntimeError, + MontySyntaxError, + MontyTypingError, + NameLookupSnapshot, + ResourceLimits, + ) +except ImportError as _import_error: + raise ImportError( + 'Please install `pydantic-monty` to use MontyEnvironment, ' + 'you can use the `monty` optional group — `pip install "pydantic-harness[monty]"`' + ) from _import_error + + +class MontyEnvironment(ExecutionEnvironment): + """Execution environment using the Monty sandboxed REPL interpreter. + + Monty provides sandboxed execution with state persistence across calls. + Code is executed incrementally via `MontyRepl.feed_start()`, which pauses + at every external function call, returning a snapshot that can be resumed + once the host has computed the function's return value. + + Print output from executed code is captured and returned alongside the + result value. + + This environment only supports code execution (`run_code` capability). + It does not provide shell, file, or search operations. + """ + + execution_timeout: float | None = None + """Optional timeout in seconds for code execution. None means no timeout.""" + + _repl: MontyRepl | None = None + + @property + def capabilities(self) -> frozenset[Capability]: + return frozenset({'run_python', 'run_python_with_functions'}) + + def instructions(self, capability: Capability) -> str | None: + if capability in ('run_python', 'run_python_with_functions'): + return textwrap.dedent( + """ + The runtime uses a restricted Python subset: + - you cannot use the standard library except builtin functions and the following modules: `sys`, `typing`, `asyncio`, `math`, `re` + - this means `collections`, `json`, `datetime`, `itertools`, `functools`, etc. are NOT available — use plain dicts, lists, and builtins instead + - you cannot use third party libraries + - you cannot define classes + - chained subscript assignment like `x[a][b] = val` is NOT supported — read into a local variable, modify it, then assign back: `inner = x[a]; inner[b] = val; x[a] = inner` + + State persists across calls — variables and functions defined in previous calls are available in subsequent calls. + + The last expression evaluated is the return value. + + Parallelism: use `asyncio.gather` to fire multiple calls at the same time instead of awaiting each one sequentially: + + # GOOD — parallel (all calls fire at once): + results = await asyncio.gather( + get_data(id=1), + get_data(id=2), + get_data(id=3), + ) + + # BAD — sequential (each call waits before the next starts): + r1 = await get_data(id=1) + r2 = await get_data(id=2) + r3 = await get_data(id=3) + """ + ) + return None # pragma: no cover + + async def __aenter__(self) -> Self: + self._ensure_repl() + return self + + async def __aexit__(self, *args: object) -> None: + self._repl = None + + def reset(self) -> None: + """Discard the current REPL session so the next execution starts fresh.""" + self._repl = None + + def type_check( + self, + code: str, + *, + signatures: list[FunctionSignature] | None = None, + referenced_types: list[TypeSignature] | None = None, + ) -> None: + """Type-check code using a stateless Monty instance. + + This is only sound when the REPL has no accumulated state (i.e. after a + reset), because the stateless type checker has no visibility into + variables defined in prior REPL snippets. + + Args: + code: The Python code to type-check. + signatures: Function signatures to include as type stubs. + referenced_types: Type definitions referenced by the signatures. + + Raises: + CodeTypingError: If type errors are found. + CodeSyntaxError: If the code has a syntax error. + """ + prefix = self._build_type_check_prefix(signatures or [], referenced_types or []) + try: + Monty(code, type_check=True, type_check_stubs=prefix) + except MontySyntaxError as e: + raise CodeSyntaxError(e.display()) from e + except MontyTypingError as e: + raise CodeTypingError(e.display()) from e + + # TODO: Concurrent agent runs sharing a MontyEnvironment will fail because + # MontyRepl has an internal mutex — only one snippet can execute at a time. + # Once PR #4688 (for_run/for_run_step lifecycle hooks) lands, + # CodeExecutionToolset.for_run() should return a new instance with a fresh + # MontyEnvironment so each agent run gets its own REPL with isolated state. + + def _ensure_repl(self) -> MontyRepl: + """Return the active REPL, creating it lazily if needed.""" + if self._repl is None: + limits = ( + ResourceLimits(max_duration_secs=self.execution_timeout) if self.execution_timeout is not None else None + ) + self._repl = MontyRepl(limits=limits) + self._repl.feed_start('import asyncio') + return self._repl + + async def run_python(self, code: str) -> Any: + """Execute code in the Monty REPL sandbox without external functions.""" + prints: list[str] = [] + try: + monty_state = self._ensure_repl().feed_start(code, print_callback=lambda _stream, text: prints.append(text)) + # Handle NameLookupSnapshot by resuming without a value (raises NameError in sandbox) + while isinstance(monty_state, NameLookupSnapshot): + monty_state = monty_state.resume() + if not isinstance(monty_state, MontyComplete): + raise CodeRuntimeError( + 'Unexpected external function call in code without functions.' + ) # pragma: no cover + except MontySyntaxError as e: + raise CodeSyntaxError(self._prepend_prints(e.display(), prints)) from e + except MontyRuntimeError as e: + self._raise_if_timeout(e, prints) + raise CodeRuntimeError(self._prepend_prints(e.display(), prints)) from e + + return self._build_result(monty_state.output, prints) + + async def run_python_with_functions( + self, + code: str, + *, + function_callback: FunctionCallback, + functions: dict[str, FunctionSignature] | None = None, + referenced_types: list[TypeSignature] | None = None, + ) -> Any: + """Execute code in the Monty REPL sandbox with external function support.""" + if functions is None: + functions = {} + + prints: list[str] = [] + try: + monty_state = self._ensure_repl().feed_start(code, print_callback=lambda _stream, text: prints.append(text)) + monty_state = await self._execution_loop(monty_state, function_callback, functions=functions) + except MontySyntaxError as e: + raise CodeSyntaxError(self._prepend_prints(e.display(), prints)) from e + except MontyRuntimeError as e: + self._raise_if_timeout(e, prints) + raise CodeRuntimeError(self._prepend_prints(e.display(), prints)) from e + + return self._build_result(monty_state.output, prints) + + @staticmethod + def _build_result(output: Any, prints: list[str]) -> Any: + """Combine the expression result with any captured print output.""" + printed_text = ''.join(prints).rstrip('\n') + if not printed_text: + return output + + if output is None: + return printed_text + + return {'stdout': printed_text, 'result': output} + + @staticmethod + def _prepend_prints(error_message: str, prints: list[str]) -> str: + """Prepend any captured print output to an error message.""" + printed_text = ''.join(prints).rstrip('\n') + if not printed_text: + return error_message + return f'[stdout before error]\n{printed_text}\n[/stdout before error]\n{error_message}' + + def _build_type_check_prefix( + self, signatures: list[FunctionSignature], referenced_types: list[TypeSignature] + ) -> str: + """Build the prefix code used for Monty type checking.""" + parts = ['import asyncio\nfrom typing import Any, TypedDict, NotRequired, Literal'] + parts.extend(str(t) for t in referenced_types) + parts.extend(sig.render('raise NotImplementedError()') for sig in signatures) + + return '\n\n'.join(parts) + + def _raise_if_timeout(self, e: MontyRuntimeError, prints: list[str]) -> None: + """Raise CodeExecutionTimeout if the MontyRuntimeError is a time limit violation.""" + if self.execution_timeout is not None and 'time limit exceeded' in e.display(): + msg = f'Code execution timed out after {self.execution_timeout} seconds' + raise CodeExecutionTimeout(self._prepend_prints(msg, prints)) from e + + @staticmethod + async def _execution_loop( + monty_state: MontyComplete | FutureSnapshot | FunctionSnapshot | NameLookupSnapshot, + function_callback: FunctionCallback, + *, + functions: dict[str, FunctionSignature], + ) -> MontyComplete: + tasks: dict[int, asyncio.Task[Any]] = {} + try: + while not isinstance(monty_state, MontyComplete): + if isinstance(monty_state, NameLookupSnapshot): + monty_state = monty_state.resume() + continue + if isinstance(monty_state, FunctionSnapshot): + call = FunctionCall( + call_id=f'monty_{monty_state.call_id}', + function_name=monty_state.function_name, + args=monty_state.args, + kwargs=monty_state.kwargs, + ) + sig = functions.get(monty_state.function_name) + + if sig and not sig.is_async: + # Sequential: drain pending async tasks, then call synchronously + if tasks: + await asyncio.gather(*tasks.values()) + result = await function_callback(call) + monty_state = monty_state.resume(return_value=result) + else: + # Async: fire and defer (existing behavior) + tasks[monty_state.call_id] = asyncio.ensure_future(function_callback(call)) + monty_state = monty_state.resume(future=...) + elif isinstance(monty_state, FutureSnapshot): + pending_call_ids = monty_state.pending_call_ids + if not pending_call_ids: # pragma: no cover + monty_state = monty_state.resume(results={}) + continue + + try: + pending_tasks = [tasks[call_id] for call_id in pending_call_ids] + except KeyError as e: # pragma: no cover + raise CodeRuntimeError( + f'Monty expects results for call IDs {pending_call_ids} but no tasks exist' + ) from e + + try: + task_results = await asyncio.gather(*pending_tasks) + except BaseException: + for t in pending_tasks: + t.cancel() + raise + finally: + for call_id in pending_call_ids: + del tasks[call_id] + + monty_state = monty_state.resume( + results={ + call_id: ExternalReturnValue(return_value=result) + for call_id, result in zip(pending_call_ids, task_results) + } + ) + else: + assert_never(monty_state) + finally: + for t in tasks.values(): + t.cancel() + + return monty_state diff --git a/pydantic_harness/py.typed b/pydantic_harness/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/pydantic_harness/toolsets/__init__.py b/pydantic_harness/toolsets/__init__.py new file mode 100644 index 0000000..2d55d6f --- /dev/null +++ b/pydantic_harness/toolsets/__init__.py @@ -0,0 +1,6 @@ +"""Toolsets for pydantic-harness: code execution and execution environment toolsets.""" + +from .code_execution import CodeExecutionToolset +from .execution_environment import ExecutionEnvironmentToolset + +__all__ = ('CodeExecutionToolset', 'ExecutionEnvironmentToolset') diff --git a/pydantic_harness/toolsets/code_execution/__init__.py b/pydantic_harness/toolsets/code_execution/__init__.py new file mode 100644 index 0000000..3608357 --- /dev/null +++ b/pydantic_harness/toolsets/code_execution/__init__.py @@ -0,0 +1,487 @@ +"""Code execution toolset that optionally wraps tools as Python functions callable from generated code.""" + +from __future__ import annotations + +import copy +import keyword +import re +from collections.abc import Callable +from dataclasses import KW_ONLY, dataclass, replace +from typing import Any, Literal, TypeAlias, cast + +from pydantic import TypeAdapter, ValidationError +from typing_extensions import Self, TypedDict, assert_never + +from pydantic_ai.messages import tool_return_ta # private API +from pydantic_ai import exceptions +from pydantic_ai._run_context import AgentDepsT, RunContext # private API +from pydantic_ai._tool_manager import ToolManager, _parallel_execution_mode_ctx_var # pyright: ignore[reportPrivateUsage] # private API +from pydantic_ai.exceptions import ApprovalRequired, CallDeferred, ModelRetry +from pydantic_ai.messages import ToolCallPart +from pydantic_ai.tools import ToolDefinition +from pydantic_ai.toolsets.abstract import AbstractToolset, SchemaValidatorProt, ToolsetTool + +from pydantic_harness._python_signature import ( + FunctionSignature, + TypeSignature, + collect_unique_referenced_types, + dedup_referenced_types, + schema_to_signature, +) +from pydantic_harness.environments._base import ExecutionEnvironment +from ._abstract import ( + CodeExecutionError, + CodeExecutionTimeout, + CodeRuntimeError, + CodeSyntaxError, + CodeTypingError, + FunctionCall, + FunctionCallback, +) + +__all__ = ( + 'CodeExecutionError', + 'CodeExecutionTimeout', + 'CodeExecutionToolset', + 'CodeRuntimeError', + 'CodeSyntaxError', + 'CodeTypingError', + 'DescriptionFunc', + 'FunctionCall', + 'FunctionCallback', + 'FunctionSignature', + 'TypeSignature', + 'build_default_description', +) + + +EnvironmentName = Literal['monty'] + + +def get_environment(name: EnvironmentName) -> ExecutionEnvironment: + """Create an execution environment by name. + + Args: + name: The environment name (`'monty'`). + + Returns: + A new `ExecutionEnvironment` instance. + """ + if name == 'monty': + from pydantic_harness.environments.monty import MontyEnvironment + + return MontyEnvironment() + else: + assert_never(name) + + +# The `restart` parameter follows the pattern established by Anthropic's bash tool +# (`bash_20250124`), where `restart: true` clears session state. Models are already +# trained to understand this interaction pattern — it's a well-understood primitive +# for "discard accumulated state and start fresh". + + +class _CodeToolArguments(TypedDict, total=False): + code: str + restart: bool + + +_CODE_ADAPTER = TypeAdapter(_CodeToolArguments) +_CODE_VALIDATOR = _CODE_ADAPTER.validator +_CODE_JSON_SCHEMA = _CODE_ADAPTER.json_schema() +_TOOL_NAME = 'run_code' + +_BASE_PROMPT = """ +Use this tool to run Python code. + +Execution model: +- This is a REPL session — state persists across calls. Variables, functions, and imports defined in previous calls are available in subsequent calls. +- If a previous call failed, the state from earlier *successful* calls is still intact — you only need to fix the failed snippet. +- You can create new functions for convenience. +- This tool is for running code — don't use it just to format or print your final analysis. + +Session management: +- Set `restart: true` to clear all accumulated state and start a fresh session. You can combine it with `code` to reset and run in one call, or use it alone to just reset. +- Use restart when your session state is corrupted or you want a completely clean slate. +""" + +_TOOLS_PROMPT = """ +Use this tool to run Python code that can call other tools as functions. + +You can use it to: +- filter tool return data to save context, +- perform complex operations that would take many model calls using standard tool calling, or +- pass the result of one tool to another without it entering your context window. + +Execution model: +- This is a REPL session — state persists across calls. Variables, functions, and imports defined in previous calls are available in subsequent calls. You can split work across multiple calls and build on earlier results. +- If a previous call failed, the state from earlier *successful* calls is still intact — you only need to fix the failed snippet, not rewrite everything from scratch. +- You can create new functions for convenience. +- This tool is for calling and chaining tools programmatically — don't use it just to format or print your final analysis. Write your report as regular text in your response. + +Session management: +- Set `restart: true` to clear all accumulated state and start a fresh session. You can combine it with `code` to reset and run in one call, or use it alone to just reset. +- Use restart when your session state is corrupted or you want a completely clean slate. +""" + + +DescriptionFunc: TypeAlias = Callable[[list[FunctionSignature], list[TypeSignature], str | None], str] +"""Callback type for building the code execution tool description. + +Receives the function signatures, their referenced types, and optional +environment-specific instructions. Returns the complete tool description string. +""" + + +def build_default_description( + signatures: list[FunctionSignature], + referenced_types: list[TypeSignature], + environment_instructions: str | None, + *, + description: str | None = None, +) -> str: + """Build the default code execution tool description with the given tool signatures. + + This is the default description builder used by CodeExecutionToolset. Users can provide + their own callback via the `description` parameter, or pass a string to customize + just the preamble text while keeping the default structure. + + Args: + signatures: List of Python function signatures for available tools. + referenced_types: Unique type definitions referenced by the signatures. + environment_instructions: Environment-specific text to include in the description + (from `environment.instructions('run_python')`). Inserted verbatim if non-empty. + description: Custom preamble text to use instead of the built-in default. + + Returns: + The complete description string with preamble, available types, and function signatures. + """ + if description is None: + description = _TOOLS_PROMPT if signatures else _BASE_PROMPT + + parts = [description] + + if environment_instructions: + parts.append(environment_instructions) + + if signatures: + parts.append('```python') + + if referenced_types: + parts.append('# Available types:') + parts.extend(str(t) for t in referenced_types) + + parts.append('# Available functions:') + parts.extend(str(sig) for sig in signatures) + + parts.append('```') + + return '\n\n'.join(parts) + + +@dataclass(kw_only=True) +class _CodeExecutionTool(ToolsetTool[AgentDepsT]): + signatures: list[FunctionSignature] + referenced_types: list[TypeSignature] + name_map: dict[str, str] + tools: dict[str, ToolsetTool[AgentDepsT]] + + +@dataclass(init=False) +class CodeExecutionToolset(AbstractToolset[AgentDepsT]): + """A toolset that executes Python code, optionally with access to wrapped tools as callable functions. + + When a `toolset` is provided, its tools are exposed as callable Python functions in the code + execution context. When no `toolset` is provided, it acts as a pure code execution environment. + + Args: + environment: The code execution environment. Can be an environment instance or a string + shorthand (`'monty'`). Defaults to `'monty'`. + toolset: Optional underlying toolset to wrap. When provided, its tools are exposed as + callable Python functions in the code execution context. + description: Custom tool description. Can be a string (used as the preamble text + with the default structure) or a `DescriptionFunc` callback for full control. + Defaults to `build_default_description`. + max_retries: Maximum number of retries for code execution errors (type/syntax/runtime). + Defaults to 3. Increase for complex code generation tasks or less capable models. + """ + + environment: ExecutionEnvironment + + _: KW_ONLY + + toolset: AbstractToolset[AgentDepsT] | None + description: str | DescriptionFunc + max_retries: int = 3 + + def __init__( + self, + environment: ExecutionEnvironment | EnvironmentName = 'monty', + *, + toolset: AbstractToolset[AgentDepsT] | None = None, + description: str | DescriptionFunc = build_default_description, + max_retries: int = 3, + ) -> None: + if isinstance(environment, str): + environment = get_environment(environment) + if toolset is not None and 'run_python_with_functions' not in environment.capabilities: + raise TypeError( + f'{type(environment).__name__} does not support external functions. ' + 'Cannot wrap tools for code execution.' + ) + self.environment = environment + self.toolset = toolset + self.description = description + self.max_retries = max_retries + + @property + def id(self) -> str | None: + return None # pragma: no cover + + @property + def label(self) -> str: # pragma: no cover + if self.toolset is not None: + return f'CodeExecutionToolset({self.toolset.label})' + return 'CodeExecutionToolset' + + async def __aenter__(self) -> Self: + await self.environment.__aenter__() + try: + if self.toolset is not None: + await self.toolset.__aenter__() + except BaseException: + await self.environment.__aexit__(None, None, None) + raise + return self + + async def __aexit__(self, *args: object) -> bool | None: + try: + if self.toolset is not None: + return await self.toolset.__aexit__(*args) + return None + finally: + await self.environment.__aexit__(*args) + + def apply(self, visitor: Callable[[AbstractToolset[AgentDepsT]], None]) -> None: # pragma: no cover + if self.toolset is not None: + self.toolset.apply(visitor) + else: + visitor(self) + + def visit_and_replace( + self, visitor: Callable[[AbstractToolset[AgentDepsT]], AbstractToolset[AgentDepsT]] + ) -> AbstractToolset[AgentDepsT]: + if self.toolset is not None: + return replace(self, toolset=self.toolset.visit_and_replace(visitor)) + return visitor(self) + + async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]: + wrapped_tools: dict[str, ToolsetTool[AgentDepsT]] = {} + if self.toolset is not None: + wrapped_tools = await self.toolset.get_tools(ctx) + + deferred_tools = [name for name, tool in wrapped_tools.items() if tool.tool_def.defer] + if deferred_tools: + raise exceptions.UserError( + 'Tool approval and deferral are not yet supported in code execution mode. ' + 'Ensure wrapped tools do not use approval or deferral when used with CodeExecutionToolset.' + ) + + if wrapped_tools: + # Build sanitized name map: {sanitized_name: original_name} + # Code execution presents tools as Python function signatures to the LLM, which writes + # Python code calling them. Tool names from MCP etc. may not be valid Python + # identifiers (e.g. 'search-records', 'get.data'), so we sanitize them here. + name_map: dict[str, str] = {} # {sanitized: original} + for original_name in wrapped_tools: + sanitized = _sanitize_tool_name(original_name) + base = sanitized + counter = 2 + while sanitized in name_map: + sanitized = f'{base}_{counter}' + counter += 1 + name_map[sanitized] = original_name + + global_sequential = _parallel_execution_mode_ctx_var.get() in ('sequential', 'parallel_ordered_events') + + signatures: list[FunctionSignature] = [] + for sanitized_name, original_name in name_map.items(): + tool_def = wrapped_tools[original_name].tool_def + # TODO: When PR #4755 lands in pydantic-ai-slim, switch to: + # sig = copy.deepcopy(wrapped_tools[original_name].python_signature) + sig = copy.deepcopy(schema_to_signature( + name=original_name, + parameters_schema=tool_def.parameters_json_schema, + description=tool_def.description, + )) + sig.name = sanitized_name + sig.is_async = not (global_sequential or wrapped_tools[original_name].tool_def.sequential) + signatures.append(sig) + + dedup_referenced_types(signatures) + referenced_types = collect_unique_referenced_types(signatures) + else: + name_map = {} + signatures = [] + referenced_types = [] + + run_capability = 'run_python_with_functions' if wrapped_tools else 'run_python' + environment_instructions = self.environment.instructions(run_capability) + if isinstance(self.description, str): + tool_description = build_default_description( + signatures, referenced_types, environment_instructions, description=self.description + ) + else: + tool_description = self.description(signatures, referenced_types, environment_instructions) + + return { + _TOOL_NAME: _CodeExecutionTool( + toolset=self, + signatures=signatures, + referenced_types=referenced_types, + name_map=name_map, + tools=wrapped_tools, + tool_def=ToolDefinition( + name=_TOOL_NAME, + parameters_json_schema=_CODE_JSON_SCHEMA, + description=tool_description, + metadata={'code_arg_name': 'code', 'code_arg_language': 'python'}, + ), + max_retries=self.max_retries, + args_validator=cast(SchemaValidatorProt, _CODE_VALIDATOR), + ) + } + + async def call_tool( + self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT] + ) -> Any: + assert name == _TOOL_NAME + assert isinstance(tool, _CodeExecutionTool) + + code = tool_args.get('code') + restart = tool_args.get('restart', False) + + if not code and not restart: + raise ModelRetry('Either `code` or `restart: true` (or both) must be provided.') + + if restart: + self.environment.reset() + + if not code: + return 'Session restarted successfully.' + + # On restart, the REPL state is clean — no accumulated variables from prior + # snippets — so stateless type checking is sound. + if restart: + try: + self.environment.type_check( + code, + signatures=tool.signatures or None, + referenced_types=tool.referenced_types or None, + ) + except CodeTypingError as e: + raise ModelRetry(f'Type error in generated code:\n{e.message}') from e + except CodeSyntaxError as e: + raise ModelRetry(f'Syntax error in generated code:\n{e.message}') from e + + return await self._execute_code(code, tool, ctx) + + async def _execute_code( + self, + code: str, + tool: _CodeExecutionTool[AgentDepsT], + ctx: RunContext[AgentDepsT], + ) -> Any: + """Execute code in the environment, dispatching tool calls if needed.""" + tool_manager: ToolManager[AgentDepsT] | None = None + if self.toolset is not None: + tool_manager = ToolManager( + toolset=self.toolset, + ctx=ctx, + tools=tool.tools, + ) + + async def function_callback(call: FunctionCall) -> Any: + sanitized_name = call.function_name + original_name = tool.name_map.get(sanitized_name, sanitized_name) + + try: + if tool_manager is None: # pragma: no cover + raise ModelRetry('No tools available') + + if call.args: + raise ModelRetry( + 'Positional arguments are not supported in code mode tool calls. All parameters are keyword-only.' + ) + + tool_call = ToolCallPart(tool_name=original_name, args=call.kwargs, tool_call_id=call.call_id) + result = await tool_manager.handle_call(tool_call, wrap_validation_errors=False) + + return tool_return_ta.dump_python(result) + except (CallDeferred, ApprovalRequired): + raise exceptions.UserError( + 'Tool approval and deferral are not yet supported in code execution mode. ' + 'Ensure wrapped tools do not use approval or deferral when used with CodeExecutionToolset.' + ) + except (ModelRetry, ValidationError) as e: + raise CodeRuntimeError(f'Call to {sanitized_name!r} failed: {e}') from e + + try: + if 'run_python_with_functions' in self.environment.capabilities: + return await self.environment.run_python_with_functions( + code, + function_callback=function_callback, + functions={sig.name: sig for sig in tool.signatures}, + referenced_types=tool.referenced_types, + ) + else: + return await self.environment.run_python(code) + except CodeTypingError as e: + raise ModelRetry(f'Type error in generated code:\n{e.message}') from e + except CodeSyntaxError as e: + raise ModelRetry(f'Syntax error in generated code:\n{e.message}') from e + except CodeRuntimeError as e: + raise ModelRetry(f'Runtime error in generated code:\n{e.message}') from e + + +def _sanitize_tool_name(name: str) -> str: + """Convert a tool name to a valid Python identifier. + + Args: + name: The original tool name (may contain hyphens, dots, etc.) + + Returns: + A valid Python identifier in snake_case. + + Examples: + >>> _sanitize_tool_name('search-records') + 'search_records' + >>> _sanitize_tool_name('get.user.data') + 'get_user_data' + >>> _sanitize_tool_name('class') # Python keyword + 'class_' + """ + # Replace common separators with underscores + sanitized = re.sub(r'[-.\s]+', '_', name) + + # Remove any remaining invalid characters (keep alphanumeric and underscore) + sanitized = re.sub(r'[^a-zA-Z0-9_]', '', sanitized) + + # Ensure it doesn't start with a digit + if sanitized and sanitized[0].isdigit(): + sanitized = f'_{sanitized}' + + # Handle empty result + if not sanitized: + sanitized = 'tool' + + # Convert to snake_case if it's camelCase or PascalCase + # Insert underscore before uppercase letters that follow lowercase + sanitized = re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', sanitized) + sanitized = re.sub(r'([A-Z]+)([A-Z][a-z])', r'\1_\2', sanitized).lower() + + # Handle Python keywords by appending underscore + if keyword.iskeyword(sanitized): + sanitized = f'{sanitized}_' + + return sanitized diff --git a/pydantic_harness/toolsets/code_execution/_abstract.py b/pydantic_harness/toolsets/code_execution/_abstract.py new file mode 100644 index 0000000..859ec80 --- /dev/null +++ b/pydantic_harness/toolsets/code_execution/_abstract.py @@ -0,0 +1,49 @@ +"""Data types for the code execution layer. + +This module defines error types, the `FunctionCall` dataclass, and the +`FunctionCallback` type alias used by `CodeExecutionToolset` and execution +environments. +""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from dataclasses import dataclass, field +from typing import Any, TypeAlias + + +@dataclass(frozen=True) +class FunctionCall: + """Represents a call to an external function made by executing code.""" + + call_id: str + function_name: str + args: tuple[Any, ...] = () + kwargs: dict[str, Any] = field(default_factory=dict[str, Any]) + + +class CodeExecutionError(Exception): + """Base for all code execution errors.""" + + def __init__(self, message: str): + self.message = message + super().__init__(message) + + +class CodeSyntaxError(CodeExecutionError): + """The generated code has a syntax error.""" + + +class CodeTypingError(CodeExecutionError): + """The generated code has a type error.""" + + +class CodeRuntimeError(CodeExecutionError): + """The generated code raised an exception at runtime.""" + + +class CodeExecutionTimeout(CodeRuntimeError): + """The code execution exceeded the configured timeout.""" + + +FunctionCallback: TypeAlias = Callable[[FunctionCall], Awaitable[Any]] diff --git a/pydantic_harness/toolsets/code_execution/_driver.py b/pydantic_harness/toolsets/code_execution/_driver.py new file mode 100644 index 0000000..c203e08 --- /dev/null +++ b/pydantic_harness/toolsets/code_execution/_driver.py @@ -0,0 +1,249 @@ +"""Stdio-based sandbox driver for CPython runtimes. + +Self-contained script with zero dependencies beyond the Python 3.10+ stdlib. +Runs inside any sandbox (Docker, Modal, etc.) and communicates with the +host via NDJSON (newline-delimited JSON) over stdin/stdout. + +Protocol: + Host -> Driver: init, result, error + Driver -> Host: call, calls_ready, complete, error + +This file is both a module (importable for path resolution) and an executable +script (python -u _driver.py). +""" + +from __future__ import annotations + +import ast +import asyncio +import json +import sys +import traceback +from typing import Any + +# Protocol channel — overridden in __main__ with an fd-level redirect. +_real_stdout = sys.stdout + + +def _write_msg(msg: dict[str, Any]) -> None: + """Write a single NDJSON message to the real stdout.""" + _real_stdout.write(json.dumps(msg, default=str) + '\n') + _real_stdout.flush() + + +def _transform_last_expr(code: str) -> str: + """If the last statement is an expression, wrap it in a return statement. + + Uses AST parsing so multiline expressions (dicts, lists, etc.) are handled + correctly. String-based approaches break on edge cases. + """ + tree = ast.parse(code) + if tree.body and isinstance(tree.body[-1], ast.Expr): + last_expr = tree.body[-1] + ret = ast.Return(value=last_expr.value) + ast.copy_location(ret, last_expr) + tree.body[-1] = ret + ast.fix_missing_locations(tree) + return ast.unparse(tree) + + +def _build_proxy( + name: str, + loop: asyncio.AbstractEventLoop, + call_counter: list[int], + pending_futures: dict[int, asyncio.Future[Any]], + result_cache: dict[str, Any], + calls_ready_handle: list[asyncio.Handle | None], +) -> Any: + """Build an eager proxy function for a declared tool. + + Calling the proxy immediately sends a call message and returns a future. + `await` just waits for the result. This enables fire-then-await parallelism: + `f1 = tool_a(); f2 = tool_b(); r1 = await f1; r2 = await f2` fires both + calls instantly. + + NOT async — returns a future directly. If this were `async def`, calling + without `await` would return an unstarted coroutine (no parallelism). + + After sending a call message, a `calls_ready` boundary message is + scheduled via `loop.call_soon`. Each proxy call cancels the previous + handle, so exactly one `calls_ready` is emitted after the last + synchronous proxy call in a batch — when the event loop runs on the + next `await`. + """ + + def proxy(*args: Any, **kwargs: Any) -> asyncio.Future[Any]: + call_counter[0] += 1 + cid = call_counter[0] + + if str(cid) in result_cache: + f: asyncio.Future[Any] = loop.create_future() + f.set_result(result_cache[str(cid)]) + return f + + future: asyncio.Future[Any] = loop.create_future() + pending_futures[cid] = future + _write_msg({'type': 'call', 'id': cid, 'function': name, 'args': list(args), 'kwargs': kwargs}) + + # Schedule a calls_ready fence to fire when the event loop runs next. + if calls_ready_handle[0] is not None: + calls_ready_handle[0].cancel() + calls_ready_handle[0] = loop.call_soon(_write_msg, {'type': 'calls_ready'}) + + return future + + return proxy + + +def _compile_code(code: str, code_globals: dict[str, Any]) -> Any | None: + """Parse, transform, and compile LLM code into an async function. + + Returns the compiled async function, or None if a syntax error was + reported (error message already sent to host). + """ + try: + transformed = _transform_last_expr(code) + except SyntaxError as e: + _write_msg({'type': 'error', 'error': str(e), 'error_type': 'syntax'}) + return None + + func_code = 'async def __code__():\n' + for line in transformed.splitlines(): + func_code += f' {line}\n' + + try: + compiled = compile(func_code, '', 'exec') + except SyntaxError as e: # pragma: no cover + _write_msg({'type': 'error', 'error': str(e), 'error_type': 'syntax'}) + return None + + exec(compiled, code_globals) + return code_globals['__code__'] + + +async def _stdin_reader( + reader: asyncio.StreamReader, + pending_futures: dict[int, asyncio.Future[Any]], +) -> None: + """Read result/error messages from stdin and resolve corresponding futures.""" + while True: + raw = await reader.readline() + if not raw: + break + try: + msg = json.loads(raw) + except json.JSONDecodeError: + sys.stderr.write(f'Warning: malformed JSON from host: {raw[:200]!r}\n') + continue + + msg_type = msg.get('type') + cid = msg.get('id') + + if msg_type == 'result' and cid is not None: + future = pending_futures.pop(cid, None) + if future is not None and not future.done(): + future.set_result(msg.get('result')) + elif msg_type == 'error' and cid is not None: + future = pending_futures.pop(cid, None) + if future is not None and not future.done(): + future.set_exception(RuntimeError(msg.get('error', 'Tool error'))) + + +async def _execute(init_msg: dict[str, Any], reader: asyncio.StreamReader) -> None: + """Parse code, build proxies, execute, and report the result.""" + code: str = init_msg.get('code', '') + functions: list[str] = init_msg.get('functions', []) + # TODO(sequential): Read `sequential_functions` from init_msg and build sync proxies + # for them. Sequential proxies should either (a) be async coroutines that await their + # result future inline rather than returning it eagerly, or (b) send a special + # `sync_call` message type where the driver blocks until the result arrives before + # continuing code execution. See MontyEnvironment._execution_loop for reference. + result_cache: dict[str, Any] = init_msg.get('result_cache', {}) + + if not code.strip(): + _write_msg({'type': 'complete', 'result': None}) + return + + loop = asyncio.get_running_loop() + call_counter: list[int] = [0] + pending_futures: dict[int, asyncio.Future[Any]] = {} + calls_ready_handle: list[asyncio.Handle | None] = [None] + + code_globals: dict[str, Any] = {'__builtins__': __builtins__, 'asyncio': asyncio} + for name in functions: + code_globals[name] = _build_proxy(name, loop, call_counter, pending_futures, result_cache, calls_ready_handle) + + code_fn = _compile_code(code, code_globals) + if code_fn is None: + return + + stdin_task = asyncio.create_task(_stdin_reader(reader, pending_futures)) + + try: + result = await code_fn() + _write_msg({'type': 'complete', 'result': result}) + except SyntaxError as e: + _write_msg({'type': 'error', 'error': str(e), 'error_type': 'syntax'}) + except Exception: + _write_msg({'type': 'error', 'error': traceback.format_exc(), 'error_type': 'runtime'}) + finally: + stdin_task.cancel() + try: + await stdin_task + except asyncio.CancelledError: + pass + + +async def _main(proto_stdin: Any) -> None: + """Entry point: connect stdin, read init message, execute code.""" + loop = asyncio.get_running_loop() + reader = asyncio.StreamReader() + protocol = asyncio.StreamReaderProtocol(reader) + await loop.connect_read_pipe(lambda: protocol, proto_stdin) + + line = await reader.readline() + if not line: # pragma: lax no cover + _write_msg({'type': 'error', 'error': 'No init message received', 'error_type': 'runtime'}) + return + + try: + init_msg = json.loads(line) + except json.JSONDecodeError as e: # pragma: lax no cover + _write_msg({'type': 'error', 'error': f'Invalid init message: {e}', 'error_type': 'runtime'}) + return + + if init_msg.get('type') != 'init': # pragma: lax no cover + _write_msg( + { + 'type': 'error', + 'error': f'Expected init message, got: {init_msg.get("type")}', + 'error_type': 'runtime', + } + ) + return + + await _execute(init_msg, reader) + + +if __name__ == '__main__': + import os + + # Flush Python-level buffers before touching the underlying fds. + sys.stdout.flush() + sys.stderr.flush() + + # Save protocol fds to new fd numbers, then redirect fd 0/1 so LLM code + # (os.write, os.read, subprocesses, C extensions) cannot read/corrupt + # the protocol channel. The duped fds survive the redirect. + _proto_stdin_fd = os.dup(0) + _proto_stdout_fd = os.dup(1) + + os.dup2(2, 1) # fd 1 → stderr + _devnull = os.open(os.devnull, os.O_RDONLY) + os.dup2(_devnull, 0) # fd 0 → /dev/null + os.close(_devnull) + + _proto_stdin = os.fdopen(_proto_stdin_fd, 'rb', buffering=0) + _real_stdout = os.fdopen(_proto_stdout_fd, 'w', buffering=1) + + asyncio.run(_main(_proto_stdin)) diff --git a/pydantic_harness/toolsets/execution_environment.py b/pydantic_harness/toolsets/execution_environment.py new file mode 100644 index 0000000..b39c01a --- /dev/null +++ b/pydantic_harness/toolsets/execution_environment.py @@ -0,0 +1,457 @@ +"""ExecutionEnvironmentToolset — exposes coding-agent-style tools backed by an ExecutionEnvironment.""" + +from __future__ import annotations + +import posixpath +import re +from asyncio import Lock +from collections.abc import Iterator +from contextlib import AsyncExitStack, contextmanager +from contextvars import ContextVar +from typing import Any, Literal + +from typing_extensions import Self + +from pydantic_harness.environments._base import ( + IMAGE_EXTENSIONS, + IMAGE_MEDIA_TYPES, + ExecutionEnvironment, +) +from pydantic_ai.exceptions import ModelRetry +from pydantic_ai.messages import BinaryContent +from pydantic_ai.toolsets.function import FunctionToolset + +Capability = Literal[ + 'ls', 'shell', 'read_file', 'write_file', 'edit_file', 'glob', 'grep', 'run_code', 'run_code_with_functions' +] +"""Toolset-level capability used in `include`/`exclude`. + +These are higher-level than the environment's fine-grained capabilities. +The toolset maps these to the appropriate environment capabilities. +""" + +EditStrategy = Literal['replace_str', 'apply_patch'] +"""Specific edit tool strategy. Expanded from the `edit_file` capability.""" + +CodeLanguage = Literal['python', 'typescript'] +"""Code execution language. Expanded from the `run_code` capability.""" + +# Capabilities that are excluded by default (handled by CodeExecutionToolset) +_DEFAULT_EXCLUDE: frozenset[Capability] = frozenset({'run_code'}) + +# Mapping from toolset-level code capabilities to per-language env capabilities +_CODE_CAPABILITY_MAP: dict[str, dict[CodeLanguage, str]] = { + 'run_code': { + 'python': 'run_python', + 'typescript': 'run_typescript', + }, + 'run_code_with_functions': { + 'python': 'run_python_with_functions', + 'typescript': 'run_typescript_with_functions', + }, +} + + +class ExecutionEnvironmentToolset(FunctionToolset[Any]): + """Toolset providing coding-agent-style tools backed by an `ExecutionEnvironment`. + + Tool names and schemas are designed to match what popular coding agents + expose, so models are well-trained on them. + + Tools are dynamically registered based on the environment's `capabilities`, + filtered by `include`/`exclude`. The `run_code` capability is excluded + by default (use `CodeExecutionToolset` for code execution). + + The environment can be: + - Passed directly at construction time (most common) + - Set/overridden via context var using `use_environment()` (for testing or per-call-site config) + + Usage: + ```python {test="skip" lint="skip"} + from pydantic_ai import Agent + from pydantic_harness.environments import ExecutionEnvironmentToolset + from pydantic_harness.environments.local import LocalEnvironment + + env = LocalEnvironment() + toolset = ExecutionEnvironmentToolset(env) + + agent = Agent('openai:gpt-5.2', toolsets=[toolset]) + + async with env: + result = await agent.run('Write a script that prints hello') + ``` + """ + + def __init__( + self, + environment: ExecutionEnvironment | None = None, + *, + include: frozenset[Capability] | None = None, + exclude: frozenset[Capability] | None = None, + edit_strategy: EditStrategy | None = None, + code_language: CodeLanguage | None = None, + require_shell_approval: bool = False, + require_write_approval: bool = False, + image_support: bool = True, + max_image_bytes: int = 50 * 1024 * 1024, + max_retries: int = 1, + id: str | None = None, + ): + """Create a new execution environment toolset. + + Args: + environment: The execution environment to use for tool execution. + Can also be set later via `use_environment()`. + include: Capabilities to include. `None` means all capabilities + from the environment (minus `run_code`). Pass an explicit set + to restrict to specific capabilities. + exclude: Capabilities to exclude. `None` defaults to `{'run_code'}`. + Use `frozenset()` to include all capabilities including `run_code`. + edit_strategy: Which edit strategy to use. `None` auto-selects + `'replace_str'` if supported by the environment. + code_language: Code execution language. `None` auto-detects + from the environment's capabilities (defaults to `'python'`). + require_shell_approval: Whether the `shell` tool requires human-in-the-loop + approval before execution. Recommended for `LocalEnvironment` where + commands run directly on the host. + require_write_approval: Whether `write_file` and edit tools require + human-in-the-loop approval before execution. + image_support: Whether `read_file` should return images as `BinaryContent` + for multimodal models (otherwise returns a placeholder message). + max_image_bytes: Maximum image file size to return as BinaryContent. + max_retries: Maximum retries per tool call. + id: Optional unique ID for the toolset (required for durable execution). + """ + super().__init__(max_retries=max_retries, id=id) + self._default_environment = environment + self._environment_override: ContextVar[ExecutionEnvironment | None] = ContextVar( + f'_environment_override_{id or "environment"}', default=None + ) + self._include = include + self._exclude = exclude if exclude is not None else _DEFAULT_EXCLUDE + self._edit_strategy: EditStrategy | None = edit_strategy + self._code_language: CodeLanguage | None = code_language + self._image_support = image_support + self._max_image_bytes = max_image_bytes + self._require_shell_approval = require_shell_approval + self._require_write_approval = require_write_approval + self._enter_lock: Lock = Lock() + self._running_count: int = 0 + self._exit_stack: AsyncExitStack | None = None + + # Register tools based on what we know at init time. + # If no environment is provided, we register a full set of tools and + # let runtime errors catch unsupported capabilities. + self._register_tools(environment) + + def _resolve_capabilities(self, env: ExecutionEnvironment | None) -> set[Capability]: + """Determine which toolset-level capabilities to register as tools.""" + if env is not None: + env_caps = env.capabilities + available: set[Capability] = set() + # Map env capabilities back to toolset capabilities + for cap in ('ls', 'shell', 'read_file', 'write_file', 'glob', 'grep'): + if cap in env_caps: + available.add(cap) + # Check for edit_file: env has replace_str or apply_patch + if 'replace_str' in env_caps or 'apply_patch' in env_caps: + available.add('edit_file') + # Check for run_code: env has run_python or run_typescript + if 'run_python' in env_caps or 'run_typescript' in env_caps: + available.add('run_code') + # Check for run_code_with_functions + if 'run_python_with_functions' in env_caps or 'run_typescript_with_functions' in env_caps: + available.add('run_code_with_functions') + else: + # No environment yet — register everything (runtime will error on unsupported) + available = {'ls', 'shell', 'read_file', 'write_file', 'edit_file', 'glob', 'grep'} + + if self._include is not None: + available &= self._include + + available -= self._exclude + return available + + def _resolve_edit_tool(self, env: ExecutionEnvironment | None) -> EditStrategy | None: + """Determine which edit strategy to use.""" + if self._edit_strategy is not None: + return self._edit_strategy + if env is not None: + env_caps = env.capabilities + if 'replace_str' in env_caps: + return 'replace_str' + if 'apply_patch' in env_caps: + return 'apply_patch' + return None + # Default when no environment is available + return 'replace_str' + + def _register_tools(self, env: ExecutionEnvironment | None) -> None: + """Register tools dynamically based on capabilities.""" + caps = self._resolve_capabilities(env) + + if 'ls' in caps: + self._register_ls() + if 'shell' in caps: + self._register_shell() + if 'read_file' in caps: + self._register_read_file() + if 'write_file' in caps: + self._register_write_file() + if 'edit_file' in caps: + edit_strategy = self._resolve_edit_tool(env) + if edit_strategy == 'replace_str': + self._register_replace_str() + if 'glob' in caps: + self._register_glob() + if 'grep' in caps: + self._register_grep() + + def _register_ls(self) -> None: + async def ls(path: str = '.') -> str: + """List directory contents. + + Args: + path: The directory path to list. Defaults to the working directory. + """ + try: + entries = await self.required_environment.ls(path) + except (NotADirectoryError, PermissionError, OSError) as e: + return f'Error: {e}' + if not entries: + return 'Empty directory.' + lines: list[str] = [] + for entry in entries: + if entry.is_dir: + lines.append(f'{entry.name}/') + elif entry.size is not None: + lines.append(f'{entry.name} ({entry.size} bytes)') + else: + lines.append(entry.name) + return '\n'.join(lines) + + self.tool_plain(ls) + + def _register_shell(self) -> None: + async def shell(command: str, timeout: int = 120) -> str: + """Execute a shell command and return its output. + + Use this for running scripts, installing packages, and other terminal operations. + + Args: + command: The shell command to execute. + timeout: Maximum seconds to wait for the command to complete. + """ + result = await self.required_environment.shell(command, timeout=timeout) + parts: list[str] = [] + if result.output: + parts.append(result.output) + if result.truncated: + parts.append('[output truncated]') + parts.append(f'Exit code: {result.exit_code}') + return '\n'.join(parts) + + self.tool_plain(requires_approval=self._require_shell_approval)(shell) + + def _register_read_file(self) -> None: + async def read_file(path: str, offset: int = 0, limit: int = 2000) -> Any: + """Read a file from the filesystem. + + Returns text files with line numbers, or renders image files for visual inspection. + Use offset and limit to read specific sections of large files. + + Args: + path: The file path to read. + offset: The line number to start reading from (0-indexed). + limit: Maximum number of lines to read. + """ + try: + content = await self.required_environment.read_file(path, offset=offset, limit=limit) + if isinstance(content, bytes): + ext = posixpath.splitext(path)[1].lower() + if ext in IMAGE_EXTENSIONS: + # Image file — return as BinaryContent or placeholder + if self._image_support: + if len(content) > self._max_image_bytes: + return ( + f'Error: Image too large ({len(content)} bytes, max {self._max_image_bytes} bytes).' + ) + media_type = IMAGE_MEDIA_TYPES.get(ext, 'application/octet-stream') + return BinaryContent(data=content, media_type=media_type) + else: + return f'[Image file: {path} — image_support is disabled on this toolset]' + else: + return f'[Binary file: {path} — cannot display as text]' + return content + except (FileNotFoundError, PermissionError, ValueError, OSError) as e: + return f'Error: {e}' + + self.tool_plain(read_file) + + def _register_write_file(self) -> None: + async def write_file(path: str, content: str) -> str: + """Create or overwrite a file. + + The file and any parent directories will be created if they do not exist. + + Args: + path: The file path to write. + content: The content to write to the file. + """ + try: + await self.required_environment.write_file(path, content) + return f'File written: {path}' + except (PermissionError, OSError) as e: + return f'Error: {e}' + + self.tool_plain(requires_approval=self._require_write_approval)(write_file) + + def _register_replace_str(self) -> None: + async def replace_str(path: str, old: str, new: str, replace_all: bool = False) -> str: + """Edit a file by exact string replacement. + + The old string must match exactly (including whitespace and indentation). + For uniqueness, include surrounding context lines. + Only use this after reading the file first. + + Args: + path: The file path to edit. + old: The exact text to find and replace. + new: The replacement text. + replace_all: Replace all occurrences. Defaults to false (old must be unique). + """ + try: + count = await self.required_environment.replace_str(path, old, new, replace_all=replace_all) + return f'Replaced {count} occurrence{"s" if count != 1 else ""} in {path}.' + except (FileNotFoundError, ValueError) as e: + raise ModelRetry(str(e)) + + self.tool_plain(requires_approval=self._require_write_approval)(replace_str) + + def _register_glob(self) -> None: + async def glob_tool(pattern: str, path: str = '.') -> str: + """Find files matching a glob pattern. + + Supports patterns like `**/*.py`, `src/**/*.ts`. + Returns up to 100 matching file paths. + + Args: + pattern: The glob pattern to match files against. + path: The directory to search in. Defaults to the working directory. + """ + try: + matches = await self.required_environment.glob(pattern, path=path) + except (PermissionError, OSError) as e: + return f'Error: {e}' + if not matches: + return 'No files found.' + truncated = len(matches) > 100 + matches = matches[:100] + result = '\n'.join(matches) + if truncated: + result += '\n[... truncated, showing first 100 matches]' + return result + + self.tool_plain(name='glob')(glob_tool) + + def _register_grep(self) -> None: + async def grep_tool( + pattern: str, + path: str | None = None, + glob: str | None = None, + output_mode: Literal['content', 'files_with_matches', 'count'] = 'content', + ) -> str: + """Search file contents with a regex pattern. + + Args: + pattern: The regex pattern to search for. + path: The file or directory to search in. + glob: Glob pattern to filter which files are searched (e.g. `*.py`). + output_mode: Controls output format: + `content` (default) shows matching lines with file paths and line numbers, + `files_with_matches` shows only file paths, + `count` shows match counts per file. + """ + try: + result = await self.required_environment.grep( + pattern, path=path, glob_pattern=glob, output_mode=output_mode + ) + except (PermissionError, OSError, re.error) as e: + return f'Error: {e}' + if not result.strip(): + return 'No matches found.' + return result + + self.tool_plain(name='grep')(grep_tool) + + @property + def tool_name_conflict_hint(self) -> str: + return 'Wrap the ExecutionEnvironmentToolset in a PrefixedToolset to avoid name conflicts.' + + @property + def environment(self) -> ExecutionEnvironment | None: + """The active execution environment, or None if not configured. + + Checks the context var override first, then falls back to the default. + """ + override = self._environment_override.get() + if override is not None: + return override + return self._default_environment + + @property + def required_environment(self) -> ExecutionEnvironment: + """The active execution environment, raising if not configured. + + Raises: + RuntimeError: If no environment is available. + """ + env = self.environment + if env is not None: + return env + raise RuntimeError( + 'No execution environment configured. Pass one to ExecutionEnvironmentToolset() or use .use_environment().' + ) + + @contextmanager + def use_environment(self, environment: ExecutionEnvironment) -> Iterator[None]: + """Override the execution environment for the current context. + + Useful for testing or using different environments at different call sites. + + Usage: + ```python {test="skip" lint="skip"} + with toolset.use_environment(test_env): + result = await agent.run('test prompt', toolsets=[toolset]) + ``` + + Args: + environment: The execution environment to use within this context. + """ + token = self._environment_override.set(environment) + try: + yield + finally: + self._environment_override.reset(token) + + # --- Lifecycle --- + + async def __aenter__(self) -> Self: + async with self._enter_lock: + self._running_count += 1 + if self._running_count == 1: + self._exit_stack = AsyncExitStack() + try: + await self._exit_stack.enter_async_context(self.required_environment) + except Exception: + self._running_count -= 1 + raise + return self + + async def __aexit__(self, *args: Any) -> bool | None: + async with self._enter_lock: + self._running_count -= 1 + if self._running_count == 0 and self._exit_stack is not None: + await self._exit_stack.aclose() + self._exit_stack = None + return None diff --git a/pyproject.toml b/pyproject.toml index 8f74532..f882e9e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,14 @@ classifiers = [ 'Topic :: Software Development :: Libraries', 'Typing :: Typed', ] -dependencies = ['pydantic-ai-slim>=0.1'] +dependencies = [ + 'pydantic-ai-slim>=1.71.0', + 'pydantic>=2.12', + 'anyio>=4.5.0', +] + +[project.optional-dependencies] +monty = ['pydantic-monty>=0.0.5'] [project.urls] Homepage = 'https://github.com/pydantic/pydantic-harness' @@ -33,9 +40,14 @@ Issues = 'https://github.com/pydantic/pydantic-harness/issues' [dependency-groups] dev = [ - 'pytest', + 'pytest>=9.0.0', 'pytest-xdist', + 'pytest-mock>=3.14.0', + 'anyio>=4.5.0', 'coverage', + 'inline-snapshot>=0.19.3', + 'dirty-equals>=0.9.0', + 'pydantic-monty>=0.0.5', ] lint = [ 'ruff>=0.14', @@ -51,7 +63,7 @@ style = 'pep440' bump = true [tool.hatch.build.targets.wheel] -packages = ['src/pydantic_harness'] +packages = ['pydantic_harness'] [tool.ruff] line-length = 120 @@ -80,6 +92,7 @@ pythonVersion = '3.10' typeCheckingMode = 'strict' [tool.pytest.ini_options] +testpaths = ['tests'] xfail_strict = true filterwarnings = ['error'] diff --git a/tests/code_execution/__init__.py b/tests/code_execution/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/code_execution/conftest.py b/tests/code_execution/conftest.py new file mode 100644 index 0000000..e5321e9 --- /dev/null +++ b/tests/code_execution/conftest.py @@ -0,0 +1,138 @@ +"""Shared fixtures and test tools for code execution tests.""" + +from __future__ import annotations + +import shutil +import subprocess +from collections.abc import Callable +from typing import Any + +import pytest +from typing_extensions import TypedDict + +from pydantic_harness._python_signature import FunctionSignature, TypeSignature +from pydantic_ai._run_context import RunContext +from pydantic_harness.environments._base import ExecutionEnvironment +from pydantic_ai.models.test import TestModel +from pydantic_ai.tools import Tool +from pydantic_harness.toolsets.code_execution import CodeExecutionToolset +from pydantic_harness.toolsets.code_execution._abstract import FunctionCallback +from pydantic_ai.toolsets.function import FunctionToolset +from pydantic_ai.usage import RunUsage + + +# Define return type as TypedDict for better type hints in signatures +class WeatherResult(TypedDict): + """Weather data for a city.""" + + city: str + temperature: float + unit: str + conditions: str + + +# Simulated weather data for test cities +_WEATHER_DATA: dict[str, dict[str, Any]] = { + 'London': {'temperature': 15.0, 'conditions': 'cloudy'}, + 'Paris': {'temperature': 18.0, 'conditions': 'sunny'}, + 'Tokyo': {'temperature': 22.0, 'conditions': 'rainy'}, + 'New York': {'temperature': 12.0, 'conditions': 'windy'}, + 'Sydney': {'temperature': 25.0, 'conditions': 'sunny'}, +} + + +def get_weather(city: str) -> WeatherResult: + """Get weather for a city. + + Args: + city: Name of the city to get weather for. + + Returns: + Weather data including temperature and conditions. + """ + data = _WEATHER_DATA.get(city, {'temperature': 20.0, 'conditions': 'unknown'}) + return {'city': city, 'temperature': data['temperature'], 'unit': 'celsius', 'conditions': data['conditions']} + + +def build_run_context() -> RunContext[None]: + """Build a minimal RunContext for direct call_tool tests.""" + return RunContext( + deps=None, + model=TestModel(), + usage=RunUsage(), + prompt=None, + messages=[], + run_step=0, + ) + + +async def build_code_execution_toolset( + environment: ExecutionEnvironment, + *tools: Tool[Any] | tuple[Callable[..., Any], bool], +) -> tuple[CodeExecutionToolset[None], dict[str, Any]]: + """Build and initialize a CodeExecutionToolset, returning it along with its tools dict.""" + toolset: FunctionToolset[None] = FunctionToolset() + for tool in tools: + if isinstance(tool, Tool): + toolset.add_tool(tool) + else: + func, takes_ctx = tool + toolset.add_function(func, takes_ctx=takes_ctx) + code_execution = CodeExecutionToolset(environment, toolset=toolset) + ctx = build_run_context() + tool_defs = await code_execution.get_tools(ctx) + return code_execution, tool_defs + + +async def run_code_with_tools( + code: str, + environment: ExecutionEnvironment, + *tools: Tool[Any] | tuple[Callable[..., Any], bool], +) -> Any: + """Run code through CodeExecutionToolset. Each tool is a Tool object or (function, takes_ctx) tuple.""" + code_execution, tool_defs = await build_code_execution_toolset(environment, *tools) + ctx = build_run_context() + return await code_execution.call_tool('run_code', {'code': code}, ctx, tool_defs['run_code']) + + +class StubEnvironment(ExecutionEnvironment): + """Minimal ExecutionEnvironment for testing CodeExecutionToolset logic without pydantic-monty.""" + + @property + def capabilities(self) -> frozenset[Any]: + return frozenset({'run_python', 'run_python_with_functions'}) + + async def run_python_with_functions( + self, + code: str, + *, + function_callback: FunctionCallback, + functions: dict[str, FunctionSignature] | None = None, + referenced_types: list[TypeSignature] | None = None, + ) -> Any: + raise NotImplementedError('StubEnvironment does not execute code') + + +def _docker_is_available() -> bool: + """Check whether Docker is installed and the daemon is reachable.""" + if not shutil.which('docker'): # pragma: lax no cover + return False + try: # pragma: lax no cover + subprocess.run(['docker', 'info'], check=True, capture_output=True) # pragma: lax no cover + except (subprocess.CalledProcessError, FileNotFoundError): # pragma: lax no cover + return False + return True # pragma: lax no cover + + +@pytest.fixture(params=['monty']) +def code_environment(request: pytest.FixtureRequest) -> ExecutionEnvironment: + """Parameterized fixture providing each ExecutionEnvironment implementation for code execution.""" + if request.param == 'monty': + try: + from pydantic_harness.environments.monty import MontyEnvironment + except ImportError: + pytest.skip('pydantic-monty is not installed') + + return MontyEnvironment() + + pytest.skip(f'Unknown environment: {request.param}') # pragma: no cover diff --git a/tests/code_execution/test_code_execution.py b/tests/code_execution/test_code_execution.py new file mode 100644 index 0000000..cb0e59a --- /dev/null +++ b/tests/code_execution/test_code_execution.py @@ -0,0 +1,99 @@ +"""Test basic code execution: tool calls, parallelism, and error handling. + +Parameterized across all ExecutionEnvironment implementations. +""" + +from __future__ import annotations + +from typing import Any + +import pytest + +from pydantic_harness.environments._base import ExecutionEnvironment +from pydantic_ai.exceptions import ModelRetry + +from .conftest import run_code_with_tools + +pytestmark = pytest.mark.anyio + + +async def test_simple_execution(code_environment: ExecutionEnvironment): + """Valid code + tool call executes and returns result.""" + + def add(x: int, y: int) -> int: + return x + y + + result = await run_code_with_tools('await add(x=1, y=2)', code_environment, (add, False)) + assert result == 3 + + +async def test_parallel_execution(code_environment: ExecutionEnvironment): + """Fire-then-await runs tools in parallel.""" + + def slow_op(name: str) -> str: + return f'done:{name}' + + code = 'f1 = slow_op(name="a")\nf2 = slow_op(name="b")\nr1 = await f1\nr2 = await f2\n[r1, r2]' + result = await run_code_with_tools(code, code_environment, (slow_op, False)) + assert result == ['done:a', 'done:b'] + + +async def test_parallel_execution_gather(code_environment: ExecutionEnvironment): + """asyncio.gather runs tools in parallel.""" + + def slow_op(name: str) -> str: + return f'done:{name}' + + code = 'results = await asyncio.gather(slow_op(name="a"), slow_op(name="b"))\nlist(results)' + result = await run_code_with_tools(code, code_environment, (slow_op, False)) + assert result == ['done:a', 'done:b'] + + +async def test_no_function_calls(code_environment: ExecutionEnvironment): + """Code without function calls executes locally and returns result.""" + result = await run_code_with_tools('1 + 2', code_environment) + assert result == 3 + + +async def test_syntax_error_raises_model_retry(code_environment: ExecutionEnvironment): + """Syntax errors raise ModelRetry so the LLM can fix them.""" + with pytest.raises(ModelRetry): + await run_code_with_tools('def while invalid', code_environment) + + +async def test_runtime_error_raises_model_retry(code_environment: ExecutionEnvironment): + """Runtime exceptions raise ModelRetry so the LLM can fix them.""" + with pytest.raises(ModelRetry): + await run_code_with_tools('1 / 0', code_environment) + + +async def test_tool_exception_propagates(code_environment: ExecutionEnvironment): + """Tool exceptions propagate and crash the run, consistent with normal tool execution.""" + + def failing_tool() -> str: + raise ValueError('tool bug') + + with pytest.raises(ValueError, match='tool bug'): + await run_code_with_tools('await failing_tool()', code_environment, (failing_tool, False)) + + +async def test_execution_timeout_raises_model_retry(code_environment: Any): + """Execution timeout raises ModelRetry so the LLM is informed.""" + code_environment.execution_timeout = 1.0 + with pytest.raises(ModelRetry, match='timed out'): + await run_code_with_tools('while True: pass', code_environment) + + +async def test_positional_args_raise_model_retry(code_environment: ExecutionEnvironment): + """Positional arguments in code mode tool calls raise ModelRetry. + + Monty catches this at type-check time (too many positional arguments); + Docker catches it at call_tool_callback time (positional args not supported). + Either way the LLM gets a ModelRetry. + """ + + def add(x: int, y: int) -> int: + return x + y # pragma: no cover + + with pytest.raises(ModelRetry): + await run_code_with_tools('await add(1, 2)', code_environment, (add, False)) diff --git a/tests/code_execution/test_code_execution_integration.py b/tests/code_execution/test_code_execution_integration.py new file mode 100644 index 0000000..a41bb1b --- /dev/null +++ b/tests/code_execution/test_code_execution_integration.py @@ -0,0 +1,337 @@ +"""End-to-end integration tests: Agent + CodeExecutionToolset + FunctionModel. + +Validates the full pipeline including how the agent loop interacts with +code execution tool routing, tool execution, and result handling. +""" + +from __future__ import annotations + +import asyncio +from typing import Any + +import pytest +from inline_snapshot import snapshot + +from pydantic_ai import Agent +from pydantic_ai.messages import ( + ModelMessage, + ModelRequest, + ModelResponse, + TextPart, + ToolCallPart, + ToolReturnPart, + UserPromptPart, +) +from pydantic_ai.models.function import AgentInfo, FunctionModel +from pydantic_ai.usage import RequestUsage + +try: + from pydantic_harness.environments import monty as _monty # pyright: ignore[reportUnusedImport] # noqa: F401 +except ImportError: # pragma: lax no cover + pytest.skip('pydantic-monty is not installed', allow_module_level=True) +from pydantic_harness.toolsets.code_execution import CodeExecutionToolset +from pydantic_ai.toolsets.function import FunctionToolset + +from dirty_equals import IsDatetime, IsStr + +pytestmark = pytest.mark.anyio + + +def _make_toolset() -> FunctionToolset[None]: + """Build a simple FunctionToolset with weather + math tools.""" + + def get_weather(city: str) -> dict[str, Any]: + """Get weather for a city.""" + return {'city': city, 'temp': 20, 'conditions': 'sunny'} + + def add(x: int, y: int) -> int: + """Add two numbers.""" + return x + y + + toolset: FunctionToolset[None] = FunctionToolset() + toolset.add_function(get_weather, takes_ctx=False) + toolset.add_function(add, takes_ctx=False) + return toolset + + +async def test_agent_single_tool_call(): + """Agent calls run_code with a single tool call and returns the result as text.""" + + def model_function(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + if not any(isinstance(m, ModelResponse) for m in messages): + # First turn: call run_code + return ModelResponse( + parts=[ + ToolCallPart( + tool_name='run_code', + args={'code': 'await get_weather(city="Paris")'}, + ) + ] + ) + # After tool result: return text + return ModelResponse(parts=[TextPart('The weather in Paris is sunny at 20 degrees.')]) + + agent = Agent( + FunctionModel(model_function), + toolsets=[CodeExecutionToolset(toolset=_make_toolset())], + ) + result = await agent.run('What is the weather in Paris?') + assert result.output == 'The weather in Paris is sunny at 20 degrees.' + + +async def test_agent_multiple_tool_calls(): + """Agent writes code that calls multiple tools and processes results.""" + + def model_function(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + if not any(isinstance(m, ModelResponse) for m in messages): + code = """\ +w = await get_weather(city="London") +total = await add(x=w["temp"], y=5) +{"adjusted_temp": total, "city": w["city"]}""" + return ModelResponse(parts=[ToolCallPart(tool_name='run_code', args={'code': code})]) + return ModelResponse(parts=[TextPart('London is 25 degrees adjusted.')]) + + agent = Agent( + FunctionModel(model_function), + toolsets=[CodeExecutionToolset(toolset=_make_toolset())], + ) + result = await agent.run('Adjusted temp for London?') + assert result.output == 'London is 25 degrees adjusted.' + + +async def test_agent_parallel_fire_then_await(): + """Agent writes code using fire-then-await for parallel execution.""" + + def model_function(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + if not any(isinstance(m, ModelResponse) for m in messages): + code = """\ +f1 = get_weather(city="Paris") +f2 = get_weather(city="Tokyo") +r1 = await f1 +r2 = await f2 +[r1["city"], r2["city"]]""" + return ModelResponse(parts=[ToolCallPart(tool_name='run_code', args={'code': code})]) + return ModelResponse(parts=[TextPart('Got weather for both cities.')]) + + agent = Agent( + FunctionModel(model_function), + toolsets=[CodeExecutionToolset(toolset=_make_toolset())], + ) + result = await agent.run('Weather in Paris and Tokyo?') + assert result.output == 'Got weather for both cities.' + + +async def test_agent_code_error_triggers_retry(): + """Syntax/runtime errors in code trigger ModelRetry, and the agent can recover.""" + call_count = 0 + + def model_function(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + nonlocal call_count + call_count += 1 + if call_count == 1: + # First attempt: bad code + return ModelResponse(parts=[ToolCallPart(tool_name='run_code', args={'code': '1 / 0'})]) + if call_count == 2: + # Second attempt after retry: good code + return ModelResponse(parts=[ToolCallPart(tool_name='run_code', args={'code': 'await add(x=1, y=2)'})]) + # Final: return text + return ModelResponse(parts=[TextPart('The answer is 3.')]) + + agent = Agent( + FunctionModel(model_function), + toolsets=[CodeExecutionToolset(toolset=_make_toolset())], + ) + result = await agent.run('Add 1 and 2') + assert result.output == 'The answer is 3.' + assert call_count == 3 + + +@pytest.mark.xfail( + reason='MontyRepl mutex prevents concurrent runs on shared environment. ' + 'Needs for_run lifecycle hook from PR #4688 to create per-run REPL instances.', + raises=RuntimeError, +) +async def test_concurrent_agent_runs_on_shared_toolset(): + """Two concurrent agent.run() calls sharing a CodeExecutionToolset produce correct independent results.""" + + def add(x: int, y: int) -> int: + return x + y + + def mul(x: int, y: int) -> int: + return x * y + + toolset: FunctionToolset[None] = FunctionToolset() + toolset.add_function(add, takes_ctx=False) + toolset.add_function(mul, takes_ctx=False) + + shared_toolset = CodeExecutionToolset(toolset=toolset) + + def add_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + if not any(isinstance(m, ModelResponse) for m in messages): + return ModelResponse(parts=[ToolCallPart(tool_name='run_code', args={'code': 'await add(x=1, y=2)'})]) + return ModelResponse(parts=[TextPart('3')]) + + def mul_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + if not any(isinstance(m, ModelResponse) for m in messages): + return ModelResponse(parts=[ToolCallPart(tool_name='run_code', args={'code': 'await mul(x=3, y=4)'})]) + return ModelResponse(parts=[TextPart('12')]) + + agent_add = Agent(FunctionModel(add_model), toolsets=[shared_toolset]) + agent_mul = Agent(FunctionModel(mul_model), toolsets=[shared_toolset]) + + r1, r2 = await asyncio.gather( + agent_add.run('add 1 and 2'), + agent_mul.run('multiply 3 and 4'), + ) + assert r1.output == '3' + assert r2.output == '12' + assert r1.all_messages() == snapshot( + [ + ModelRequest( + parts=[UserPromptPart(content='add 1 and 2', timestamp=IsDatetime())], + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='run_code', + args={'code': 'await add(x=1, y=2)'}, + tool_call_id=IsStr(), + ) + ], + usage=RequestUsage(input_tokens=54, output_tokens=7), + model_name='function:add_model:', + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='run_code', + content=3, + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ], + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelResponse( + parts=[TextPart(content='3')], + usage=RequestUsage(input_tokens=55, output_tokens=8), + model_name='function:add_model:', + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ] + ) + assert r2.all_messages() == snapshot( + [ + ModelRequest( + parts=[UserPromptPart(content='multiply 3 and 4', timestamp=IsDatetime())], + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='run_code', + args={'code': 'await mul(x=3, y=4)'}, + tool_call_id=IsStr(), + ) + ], + usage=RequestUsage(input_tokens=54, output_tokens=7), + model_name='function:mul_model:', + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='run_code', + content=12, + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ], + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ModelResponse( + parts=[TextPart(content='12')], + usage=RequestUsage(input_tokens=55, output_tokens=8), + model_name='function:mul_model:', + timestamp=IsDatetime(), + run_id=IsStr(), + ), + ] + ) + + +async def test_agent_no_toolset_pure_code_execution(): + """Agent with CodeExecutionToolset() and no wrapped toolset executes pure Python.""" + + def model_function(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + if not any(isinstance(m, ModelResponse) for m in messages): + return ModelResponse(parts=[ToolCallPart(tool_name='run_code', args={'code': '2 ** 10'})]) + return ModelResponse(parts=[TextPart('1024')]) + + agent = Agent( + FunctionModel(model_function), + toolsets=[CodeExecutionToolset()], + ) + result = await agent.run('What is 2^10?') + assert result.output == '1024' + + +async def test_agent_restart_clears_state(): + """Agent can use restart=True to clear REPL state and start fresh.""" + call_count = 0 + + def model_function(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + nonlocal call_count + call_count += 1 + if call_count == 1: + # First: set a variable + return ModelResponse(parts=[ToolCallPart(tool_name='run_code', args={'code': 'x = 42'})]) + if call_count == 2: + # Second: restart and try to use x (will get NameError → retry) + return ModelResponse(parts=[ToolCallPart(tool_name='run_code', args={'code': 'x', 'restart': True})]) + if call_count == 3: + # Third: after retry, define x fresh + return ModelResponse(parts=[ToolCallPart(tool_name='run_code', args={'code': 'x = 99\nx'})]) + return ModelResponse(parts=[TextPart('x is 99')]) + + agent = Agent( + FunctionModel(model_function), + toolsets=[CodeExecutionToolset()], + ) + result = await agent.run('test restart') + assert result.output == 'x is 99' + assert call_count == 4 + + +async def test_agent_restart_only(): + """Agent can use restart=True without code to just reset the session.""" + call_count = 0 + + def model_function(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + nonlocal call_count + call_count += 1 + if call_count == 1: + return ModelResponse(parts=[ToolCallPart(tool_name='run_code', args={'code': 'x = 42'})]) + if call_count == 2: + # Just restart, no code + return ModelResponse(parts=[ToolCallPart(tool_name='run_code', args={'restart': True})]) + if call_count == 3: + # Fresh session + return ModelResponse(parts=[ToolCallPart(tool_name='run_code', args={'code': '"clean"'})]) + return ModelResponse(parts=[TextPart('done')]) + + agent = Agent( + FunctionModel(model_function), + toolsets=[CodeExecutionToolset()], + ) + result = await agent.run('test restart only') + assert result.output == 'done' diff --git a/tests/code_execution/test_code_execution_toolset.py b/tests/code_execution/test_code_execution_toolset.py new file mode 100644 index 0000000..7f7a091 --- /dev/null +++ b/tests/code_execution/test_code_execution_toolset.py @@ -0,0 +1,476 @@ +"""Tests for CodeExecutionToolset logic (description, caching, name collisions, deferred tools). + +Uses StubEnvironment so these tests don't require pydantic-monty or Docker. +""" + +from __future__ import annotations + +from typing import Any + +import pytest + +from pydantic_harness._python_signature import FunctionSignature, TypeSignature, collect_unique_referenced_types +from pydantic_ai.exceptions import UserError +from pydantic_harness.toolsets.code_execution import CodeExecutionToolset +from pydantic_ai.toolsets.function import FunctionToolset + +from .conftest import StubEnvironment, WeatherResult, build_code_execution_toolset, build_run_context, get_weather + +pytestmark = pytest.mark.anyio + + +def _add(*, x: int, y: int) -> int: + """Add two integers.""" + return x + y # pragma: no cover + + +def _get_weather_alias(city: str) -> WeatherResult: + """Get weather (alias).""" + return get_weather(city) # pragma: no cover + + +async def test_get_tools_produces_single_code_tool(): + """get_tools() returns exactly one tool named 'run_code'.""" + _, tools = await build_code_execution_toolset(StubEnvironment(), (_add, False)) + assert list(tools.keys()) == ['run_code'] + + +async def test_description_default(): + """Default description includes preamble and function signatures but no environment instructions.""" + _, tools = await build_code_execution_toolset(StubEnvironment(), (_add, False)) + description = tools['run_code'].tool_def.description or '' + # Preamble present + assert 'run Python code' in description + # Function signature present + assert 'async def _add' in description + # No environment instructions (StubEnvironment.instructions returns None) + assert 'restricted Python subset' not in description + + +async def test_description_custom_string(): + """A custom string replaces the default preamble.""" + ts: FunctionToolset[None] = FunctionToolset() + ts.add_function(_add, takes_ctx=False) + cm = CodeExecutionToolset(StubEnvironment(), toolset=ts, description='My preamble') + tools = await cm.get_tools(build_run_context()) + assert 'My preamble' in (tools['run_code'].tool_def.description or '') + + +async def test_description_custom_callback(): + """A callback gets full control over the description.""" + + def my_desc(sigs: list[FunctionSignature], types: list[TypeSignature], instructions: str | None) -> str: + return f'{len(sigs)} tools' + + ts: FunctionToolset[None] = FunctionToolset() + ts.add_function(_add, takes_ctx=False) + cm = CodeExecutionToolset(StubEnvironment(), toolset=ts, description=my_desc) + tools = await cm.get_tools(build_run_context()) + assert tools['run_code'].tool_def.description == '1 tools' + + +async def test_deferred_tools_raise_user_error(): + """Wrapping a tool with requires_approval=True triggers UserError in get_tools().""" + ts: FunctionToolset[None] = FunctionToolset() + ts.add_function(_add, takes_ctx=False, requires_approval=True) + cm = CodeExecutionToolset(StubEnvironment(), toolset=ts) + with pytest.raises(UserError, match='approval and deferral are not yet supported'): + await cm.get_tools(build_run_context()) + + +async def test_name_collision_counter(): + """3 tools sanitizing to the same base name get _2/_3 suffixes.""" + + def my_tool(*, x: int) -> int: + """First.""" + return x # pragma: no cover + + # These have different __name__ but sanitize to same snake_case + # Build toolset manually with colliding sanitized names + ts: FunctionToolset[None] = FunctionToolset() + ts.add_function(my_tool, name='my-tool', takes_ctx=False) + ts.add_function(my_tool, name='my.tool', takes_ctx=False) + ts.add_function(my_tool, name='my tool', takes_ctx=False) + cm = CodeExecutionToolset(StubEnvironment(), toolset=ts) + tools = await cm.get_tools(build_run_context()) + description = tools['run_code'].tool_def.description or '' + assert 'async def my_tool(' in description + assert 'async def my_tool_2(' in description + assert 'async def my_tool_3(' in description + + +async def test_cached_signature_reused_across_get_tools_calls(): + """Calling get_tools() twice reuses the same cached FunctionSignature objects.""" + code_execution, tools1 = await build_code_execution_toolset(StubEnvironment(), (_add, False)) + + ctx = build_run_context() + tools2 = await code_execution.get_tools(ctx) + + tool1 = tools1['run_code'] + tool2 = tools2['run_code'] + + # The code execution tool descriptions should match + assert tool1.tool_def.description == tool2.tool_def.description + + +async def test_dedup_correctness_after_cache_backed_deepcopy(): + """Multiple tools with shared types produce correct dedup after cache-backed deepcopy.""" + _code_execution, tools = await build_code_execution_toolset( + StubEnvironment(), (get_weather, False), (_get_weather_alias, False) + ) + tool = tools['run_code'] + + # Both tools reference the same WeatherResult type -- dedup should unify them + unique_types = collect_unique_referenced_types(tool.signatures) + weather_types = [t for t in unique_types if t.name == 'WeatherResult'] + assert len(weather_types) == 1 + + +async def test_aenter_cleanup_on_wrapped_failure(): + """If wrapped toolset's __aenter__ raises, environment is cleaned up.""" + from pydantic_ai.toolsets.abstract import AbstractToolset + + class FailingToolset(AbstractToolset[None]): + @property + def id(self) -> str | None: + return None + + async def __aenter__(self) -> Any: + raise RuntimeError('wrapped failed') + + async def __aexit__(self, *args: Any) -> None: + pass # pragma: no cover + + async def get_tools(self, ctx: Any) -> dict[str, Any]: + return {} # pragma: no cover + + async def call_tool(self, name: str, tool_args: Any, ctx: Any, tool: Any) -> None: + pass # pragma: no cover + + enter_count = 0 + exit_count = 0 + + class TrackingEnvironment(StubEnvironment): + async def __aenter__(self) -> Any: + nonlocal enter_count + enter_count += 1 + return self + + async def __aexit__(self, *args: Any) -> None: + nonlocal exit_count + exit_count += 1 + + failing = FailingToolset() + assert failing.id is None # verify the property works + + cm = CodeExecutionToolset(TrackingEnvironment(), toolset=failing) + with pytest.raises(RuntimeError, match='wrapped failed'): + await cm.__aenter__() + assert enter_count == 1 + assert exit_count == 1 # environment was cleaned up + + +async def test_call_deferred_during_execution(monkeypatch: pytest.MonkeyPatch): + """CallDeferred during tool execution raises UserError.""" + from pydantic_ai.exceptions import CallDeferred, UserError + from pydantic_harness.toolsets.code_execution import CodeRuntimeError, FunctionCall + + call_made = False + + class ExecutingEnvironment(StubEnvironment): + async def run_python_with_functions( + self, code: str, *, function_callback: Any, functions: Any = None, referenced_types: Any = None + ) -> Any: + nonlocal call_made + call = FunctionCall(call_id='1', function_name='_add', args=(), kwargs={'x': 1, 'y': 2}) + try: + await function_callback(call) + except UserError: + call_made = True + raise CodeRuntimeError('deferred') + + ts: FunctionToolset[None] = FunctionToolset() + ts.add_function(_add, takes_ctx=False) + cm = CodeExecutionToolset(ExecutingEnvironment(), toolset=ts) + ctx = build_run_context() + tools = await cm.get_tools(ctx) + tool = tools['run_code'] + + # Monkeypatch handle_call to raise CallDeferred + from pydantic_ai._tool_manager import ToolManager + + async def raising_handle(self: Any, tool_call: Any, **kwargs: Any) -> Any: + raise CallDeferred() + + monkeypatch.setattr(ToolManager, 'handle_call', raising_handle) + + from pydantic_ai.exceptions import ModelRetry + + with pytest.raises(ModelRetry): + await cm.call_tool('run_code', {'code': 'await _add(x=1, y=2)'}, ctx, tool) + assert call_made + + +def test_get_weather_helper(): + """Verify get_weather fixture helper returns expected data.""" + result = get_weather('London') + assert result == {'city': 'London', 'temperature': 15.0, 'unit': 'celsius', 'conditions': 'cloudy'} + + # Unknown city uses fallback data + result = get_weather('Atlantis') + assert result == {'city': 'Atlantis', 'temperature': 20.0, 'unit': 'celsius', 'conditions': 'unknown'} + + +async def test_no_toolset_produces_run_code_tool(): + """CodeExecutionToolset() with no toolset still produces a 'run_code' tool.""" + cm = CodeExecutionToolset(StubEnvironment()) + tools = await cm.get_tools(build_run_context()) + assert list(tools.keys()) == ['run_code'] + + +async def test_no_toolset_description_omits_tool_calling(): + """With no toolset, the description uses the base prompt without mentioning tool calling.""" + cm = CodeExecutionToolset(StubEnvironment()) + tools = await cm.get_tools(build_run_context()) + description = tools['run_code'].tool_def.description or '' + # Base prompt is present + assert 'run Python code' in description + # Tools prompt elements are absent + assert 'call other tools as functions' not in description + assert 'Available functions' not in description + assert 'async def' not in description + + +# --- Environment factory and compatibility tests --- + + +def test_toolset_with_incompatible_env_raises_type_error(): + """Creating CodeExecutionToolset with a toolset and env without run_python_with_functions raises TypeError.""" + from pydantic_harness.environments._base import Capability as EnvCapability, ExecutionEnvironment + + class _RunPythonOnlyEnv(ExecutionEnvironment): + @property + def capabilities(self) -> frozenset[EnvCapability]: + return frozenset({'run_python'}) + + ts: FunctionToolset[None] = FunctionToolset() + ts.add_function(_add, takes_ctx=False) + with pytest.raises(TypeError, match='does not support external functions'): + CodeExecutionToolset(_RunPythonOnlyEnv(), toolset=ts) + + +async def test_positional_args_rejected(): + """Positional arguments in a function call raise CodeRuntimeError -> ModelRetry.""" + from pydantic_harness.toolsets.code_execution._abstract import FunctionCall + + class _CallbackEnv(StubEnvironment): + async def run_python_with_functions( + self, code: str, *, function_callback: Any, functions: Any = None, referenced_types: Any = None + ) -> Any: + call = FunctionCall(call_id='1', function_name='_add', args=(1, 2), kwargs={}) + return await function_callback(call) + + ts: FunctionToolset[None] = FunctionToolset() + ts.add_function(_add, takes_ctx=False) + cm = CodeExecutionToolset(_CallbackEnv(), toolset=ts) + ctx = build_run_context() + tools = await cm.get_tools(ctx) + + from pydantic_ai.exceptions import ModelRetry + + with pytest.raises(ModelRetry, match='Runtime error'): + await cm.call_tool('run_code', {'code': '_add(1, 2)'}, ctx, tools['run_code']) + + +async def test_validation_error_becomes_runtime_error(): + """ValidationError from invalid kwargs is surfaced as ModelRetry.""" + from pydantic_harness.toolsets.code_execution._abstract import FunctionCall + + class _BadKwargsEnv(StubEnvironment): + async def run_python_with_functions( + self, code: str, *, function_callback: Any, functions: Any = None, referenced_types: Any = None + ) -> Any: + call = FunctionCall(call_id='1', function_name='_add', args=(), kwargs={'x': 'not_int', 'y': 'not_int'}) + return await function_callback(call) + + ts: FunctionToolset[None] = FunctionToolset() + ts.add_function(_add, takes_ctx=False) + cm = CodeExecutionToolset(_BadKwargsEnv(), toolset=ts) + ctx = build_run_context() + tools = await cm.get_tools(ctx) + + from pydantic_ai.exceptions import ModelRetry + + with pytest.raises(ModelRetry, match='Runtime error'): + await cm.call_tool('run_code', {'code': '_add(x="a", y="b")'}, ctx, tools['run_code']) + + +# --- Restart tests --- + + +async def test_restart_only_resets_environment(): + """restart=True without code resets the environment and returns confirmation.""" + reset_called = False + + class _TrackResetEnv(StubEnvironment): + def reset(self) -> None: + nonlocal reset_called + reset_called = True + + cm = CodeExecutionToolset(_TrackResetEnv()) + ctx = build_run_context() + tools = await cm.get_tools(ctx) + result = await cm.call_tool('run_code', {'restart': True}, ctx, tools['run_code']) + assert result == 'Session restarted successfully.' + assert reset_called + + +async def test_restart_with_code_resets_then_executes(): + """restart=True with code resets the environment, then runs the code.""" + events: list[str] = [] + + class _TrackEnv(StubEnvironment): + def reset(self) -> None: + events.append('reset') + + async def run_python_with_functions( + self, code: str, *, function_callback: Any, functions: Any = None, referenced_types: Any = None + ) -> Any: + events.append(f'exec:{code}') + return 42 + + ts: FunctionToolset[None] = FunctionToolset() + ts.add_function(_add, takes_ctx=False) + cm = CodeExecutionToolset(_TrackEnv(), toolset=ts) + ctx = build_run_context() + tools = await cm.get_tools(ctx) + result = await cm.call_tool('run_code', {'code': '1 + 1', 'restart': True}, ctx, tools['run_code']) + assert result == 42 + assert events == ['reset', 'exec:1 + 1'] + + +async def test_no_code_no_restart_raises_model_retry(): + """Empty args (no code, no restart) raises ModelRetry.""" + from pydantic_ai.exceptions import ModelRetry + + cm = CodeExecutionToolset(StubEnvironment()) + ctx = build_run_context() + tools = await cm.get_tools(ctx) + with pytest.raises(ModelRetry, match='Either `code` or `restart: true`'): + await cm.call_tool('run_code', {}, ctx, tools['run_code']) + + +async def test_restart_calls_type_check(): + """restart=True with code calls type_check before execution.""" + events: list[str] = [] + + class _TypeCheckEnv(StubEnvironment): + def reset(self) -> None: + events.append('reset') + + def type_check(self, code: str, *, signatures: Any = None, referenced_types: Any = None) -> None: + events.append(f'type_check:{code}') + + async def run_python_with_functions( + self, code: str, *, function_callback: Any, functions: Any = None, referenced_types: Any = None + ) -> Any: + events.append(f'exec:{code}') + return 'ok' + + ts: FunctionToolset[None] = FunctionToolset() + ts.add_function(_add, takes_ctx=False) + cm = CodeExecutionToolset(_TypeCheckEnv(), toolset=ts) + ctx = build_run_context() + tools = await cm.get_tools(ctx) + await cm.call_tool('run_code', {'code': 'x = 1', 'restart': True}, ctx, tools['run_code']) + assert events == ['reset', 'type_check:x = 1', 'exec:x = 1'] + + +async def test_no_restart_skips_type_check(): + """Without restart, type_check is not called (accumulated state makes it unsound).""" + type_check_called = False + + class _TypeCheckEnv(StubEnvironment): + def type_check(self, code: str, *, signatures: Any = None, referenced_types: Any = None) -> None: + nonlocal type_check_called + type_check_called = True + + async def run_python_with_functions( + self, code: str, *, function_callback: Any, functions: Any = None, referenced_types: Any = None + ) -> Any: + return 'ok' + + ts: FunctionToolset[None] = FunctionToolset() + ts.add_function(_add, takes_ctx=False) + cm = CodeExecutionToolset(_TypeCheckEnv(), toolset=ts) + ctx = build_run_context() + tools = await cm.get_tools(ctx) + await cm.call_tool('run_code', {'code': 'x = 1'}, ctx, tools['run_code']) + assert not type_check_called + + +async def test_restart_type_error_raises_model_retry(): + """Type error during restart type check surfaces as ModelRetry.""" + from pydantic_ai.exceptions import ModelRetry + from pydantic_harness.toolsets.code_execution._abstract import CodeTypingError + + class _TypeErrorEnv(StubEnvironment): + def reset(self) -> None: + pass + + def type_check(self, code: str, *, signatures: Any = None, referenced_types: Any = None) -> None: + raise CodeTypingError('x is not an int') + + cm = CodeExecutionToolset(_TypeErrorEnv()) + ctx = build_run_context() + tools = await cm.get_tools(ctx) + with pytest.raises(ModelRetry, match='Type error in generated code'): + await cm.call_tool('run_code', {'code': 'bad code', 'restart': True}, ctx, tools['run_code']) + + +async def test_restart_syntax_error_raises_model_retry(): + """Syntax error during restart type check surfaces as ModelRetry.""" + from pydantic_ai.exceptions import ModelRetry + from pydantic_harness.toolsets.code_execution._abstract import CodeSyntaxError + + class _SyntaxErrorEnv(StubEnvironment): + def reset(self) -> None: + pass + + def type_check(self, code: str, *, signatures: Any = None, referenced_types: Any = None) -> None: + raise CodeSyntaxError('unexpected EOF') + + cm = CodeExecutionToolset(StubEnvironment()) + ctx = build_run_context() + tools = await cm.get_tools(ctx) + with pytest.raises(ModelRetry, match='Syntax error in generated code'): + await cm.call_tool('run_code', {'code': 'def', 'restart': True}, ctx, tools['run_code']) + + +async def test_description_includes_session_management(): + """Tool description explains restart/session management.""" + cm = CodeExecutionToolset(StubEnvironment()) + tools = await cm.get_tools(build_run_context()) + description = tools['run_code'].tool_def.description or '' + assert 'REPL session' in description + assert 'restart' in description + assert 'clean slate' in description + + +async def test_run_python_fallback_without_functions_capability(): + """CodeExecutionToolset falls back to run_python when env lacks run_python_with_functions.""" + from pydantic_harness.environments._base import Capability as EnvCapability, ExecutionEnvironment + + class _RunPythonOnlyEnv(ExecutionEnvironment): + @property + def capabilities(self) -> frozenset[EnvCapability]: + return frozenset({'run_python'}) + + async def run_python(self, code: str) -> Any: + return 'ran_python' + + cm = CodeExecutionToolset(_RunPythonOnlyEnv()) + ctx = build_run_context() + tools = await cm.get_tools(ctx) + result = await cm.call_tool('run_code', {'code': 'print("hi")'}, ctx, tools['run_code']) + assert result == 'ran_python' diff --git a/tests/code_execution/test_monty.py b/tests/code_execution/test_monty.py new file mode 100644 index 0000000..f8063d7 --- /dev/null +++ b/tests/code_execution/test_monty.py @@ -0,0 +1,692 @@ +"""Tests for Monty runtime type checking integration.""" + +from __future__ import annotations + +from typing import Any + +import pytest +from inline_snapshot import snapshot +from pydantic import BaseModel +from typing_extensions import NotRequired, TypedDict + +try: + from pydantic_monty import Monty + + from pydantic_harness.environments.monty import MontyEnvironment +except ImportError: # pragma: lax no cover + pytest.skip('pydantic-monty is not installed', allow_module_level=True) + +from pydantic_ai._tool_manager import _parallel_execution_mode_ctx_var # pyright: ignore[reportPrivateUsage] +from pydantic_ai.exceptions import ModelRetry +from pydantic_ai.tools import Tool + +from .conftest import build_code_execution_toolset, run_code_with_tools + +pytestmark = pytest.mark.anyio + + +def add(*, x: int, y: int) -> int: + """Add two integers.""" + return x + y + + +async def test_type_error_caught_at_validation_not_type_check(): + """With type checking disabled for REPL mode, type errors are caught at Pydantic validation time. + + The tool callback validates arguments via Pydantic, so passing strings where ints are + expected raises a ValidationError wrapped as a runtime CodeRuntimeError → ModelRetry. + """ + with pytest.raises(ModelRetry, match='Runtime error in generated code'): + await run_code_with_tools('await add(x="hello", y="world")', MontyEnvironment(), (add, False)) + + +# TODO: Remove xfail when PR #4755 lands — schema_to_signature returns `-> Any` instead of `-> int` +@pytest.mark.xfail(reason='PR #4755: schema_to_signature produces less precise return types than function_to_signature') +async def test_generated_signatures_are_valid_python(): + """Generated signatures must be valid Python that Monty can parse and type check.""" + _, tools = await build_code_execution_toolset(MontyEnvironment(), (add, False)) + + tool = tools['run_code'] + env = MontyEnvironment() + prefix = env._build_type_check_prefix(tool.signatures, tool.referenced_types) # pyright: ignore[reportPrivateUsage] + + # `...` and `pass` are not valid for Monty/ty type checking — ty is intentionally + # stricter than pyright here. See https://github.com/astral-sh/ty/issues/1922 + assert prefix == snapshot('''\ +import asyncio +from typing import Any, TypedDict, NotRequired, Literal + +async def add(*, x: int, y: int) -> int: + """Add two integers.""" + raise NotImplementedError()\ +''') + # Verify Monty can parse and type check code using this prefix + m = Monty('add(x=1, y=2)', type_check=True, type_check_stubs=prefix) # Should not raise + + +async def test_signatures_use_ellipsis_monty_converts_for_type_check(): + """Signatures use '...' body; Monty converts to 'raise NotImplementedError()' for type checking.""" + _code_execution, tools = await build_code_execution_toolset(MontyEnvironment(), (add, False)) + + tool = tools['run_code'] + + # LLM-facing description should have '...' + description = tool.tool_def.description or '' + assert '...' in description + assert 'raise NotImplementedError()' not in description + + # But when Monty builds the type-check prefix, it converts to 'raise NotImplementedError()' + env = MontyEnvironment() + prefix = env._build_type_check_prefix(tool.signatures, tool.referenced_types) # pyright: ignore[reportPrivateUsage] + assert 'raise NotImplementedError()' in prefix + assert ' ...' not in prefix + + +# --- Types and tools for test_full_description_snapshot --- + + +class _Tag(TypedDict): + """A key-value tag.""" + + key: str + value: str + + +class _Resource(TypedDict): + name: str + tags: list[_Tag] + metadata: NotRequired[dict[str, str]] + """Extra metadata.""" + parent_id: int | None + + +def _find_resources(*, query: str, limit: int = 10) -> list[_Resource]: + """Find resources matching a query.""" + return [] # pragma: no cover + + +# TODO (Douwe): this doesn't actually work :) +# Both have __name__ == 'Item' to test dedup when different tools share a type name +class _SearchItem(TypedDict): + name: str + price: float + + +class _LookupItem(TypedDict): + id: int + category: str + + +def _search_items(*, query: str) -> list[_SearchItem]: + """Search for items by name.""" + return [] # pragma: no cover + + +def _get_item(*, item_id: int) -> list[_LookupItem]: + """Get items by ID.""" + return [] # pragma: no cover + + +def _tag_resource(*, resource_name: str, tag: _Tag) -> bool: + """Add a tag to a resource.""" + return True # pragma: no cover + + +# TODO: Remove xfail when PR #4755 lands — schema_to_signature produces different description format +@pytest.mark.xfail(reason='PR #4755: schema_to_signature produces different signatures than function_to_signature') +async def test_full_description_snapshot(): + """Snapshot the full run_code description with shared types, conflicts, and nesting.""" + _, tools = await build_code_execution_toolset( + MontyEnvironment(), + (_find_resources, False), + (_search_items, False), + (_get_item, False), + (_tag_resource, False), + ) + description = tools['run_code'].tool_def.description + assert description == snapshot('''\ + +Use this tool to run Python code that can call other tools as functions. + +You can use it to: +- filter tool return data to save context, +- perform complex operations that would take many model calls using standard tool calling, or +- pass the result of one tool to another without it entering your context window. + +Execution model: +- This is a REPL session — state persists across calls. Variables, functions, and imports defined in previous calls are available in subsequent calls. You can split work across multiple calls and build on earlier results. +- If a previous call failed, the state from earlier *successful* calls is still intact — you only need to fix the failed snippet, not rewrite everything from scratch. +- You can create new functions for convenience. +- This tool is for calling and chaining tools programmatically — don't use it just to format or print your final analysis. Write your report as regular text in your response. + +Session management: +- Set `restart: true` to clear all accumulated state and start a fresh session. You can combine it with `code` to reset and run in one call, or use it alone to just reset. +- Use restart when your session state is corrupted or you want a completely clean slate. + + + +The runtime uses a restricted Python subset: +- you cannot use the standard library except builtin functions and the following modules: `sys`, `typing`, `asyncio`, `math`, `re` +- this means `collections`, `json`, `datetime`, `itertools`, `functools`, etc. are NOT available — use plain dicts, lists, and builtins instead +- you cannot use third party libraries +- you cannot define classes +- chained subscript assignment like `x[a][b] = val` is NOT supported — read into a local variable, modify it, then assign back: `inner = x[a]; inner[b] = val; x[a] = inner` + +State persists across calls — variables and functions defined in previous calls are available in subsequent calls. + +The last expression evaluated is the return value. + +Parallelism: use `asyncio.gather` to fire multiple calls at the same time instead of awaiting each one sequentially: + + # GOOD — parallel (all calls fire at once): + results = await asyncio.gather( + get_data(id=1), + get_data(id=2), + get_data(id=3), + ) + + # BAD — sequential (each call waits before the next starts): + r1 = await get_data(id=1) + r2 = await get_data(id=2) + r3 = await get_data(id=3) + + +```python + +# Available types: + +class _Tag(TypedDict): + """A key-value tag.""" + key: str + value: str + +class _Resource(TypedDict): + name: str + tags: list[_Tag] + metadata: NotRequired[dict[str, str]] + parent_id: int | None + +class _SearchItem(TypedDict): + name: str + price: float + +class _LookupItem(TypedDict): + id: int + category: str + +# Available functions: + +async def _find_resources(*, query: str, limit: int = 10) -> list[_Resource]: + """Find resources matching a query.""" + ... + +async def _search_items(*, query: str) -> list[_SearchItem]: + """Search for items by name.""" + ... + +async def _get_item(*, item_id: int) -> list[_LookupItem]: + """Get items by ID.""" + ... + +async def _tag_resource(*, resource_name: str, tag: _Tag) -> bool: + """Add a tag to a resource.""" + ... + +```\ +''') + + +async def test_monty_runtime_error_raises_model_retry(): + """MontyRuntimeError during execution is surfaced as ModelRetry.""" + with pytest.raises(ModelRetry, match='Runtime error in generated code'): + await run_code_with_tools('1 / 0', MontyEnvironment(), (add, False)) + + +async def test_monty_syntax_error_message(): + """Monty syntax errors include a descriptive message for the LLM.""" + with pytest.raises(ModelRetry) as exc_info: + await run_code_with_tools('def while invalid syntax', MontyEnvironment()) + + assert str(exc_info.value) == snapshot("""\ +Syntax error in generated code: +Expected an identifier, but found a keyword `while` that cannot be used here at byte range 4..9\ +""") + + +class UserModel(BaseModel): + name: str + age: int + + +async def test_monty_normalizes_tool_results_to_json_compatible(): + """Tool results fed to Monty should be JSON-compatible (dicts, not BaseModels). + + Without normalization, Monty receives the raw Python object (e.g. a Pydantic + BaseModel), but driver-based runtimes serialize results over JSON and would + receive a plain dict. This inconsistency means code that works on one + runtime could break on another. + + The fix: normalize all results to JSON-compatible form (via + tool_return_ta.dump_python(mode='json')) before feeding them to Monty, + matching what driver-based runtimes already do. + + This test exposes the issue by passing a tool result to a second tool that + observes the actual Python type on the host side. + """ + + def get_user(id: int) -> UserModel: + """Get a user by ID.""" + return UserModel(name='Alice', age=30) + + received_types: list[str] = [] + + def inspect_type(data: Any) -> str: + """Record the Python type of the received data.""" + received_types.append(type(data).__name__) + return type(data).__name__ + + code = 'user = await get_user(id=1)\nawait inspect_type(data=user)' + + result = await run_code_with_tools( + code, + MontyEnvironment(), + (get_user, False), + (inspect_type, False), + ) + + # After normalization, the second tool should receive a dict, not a UserModel. + # This guarantees consistent behavior across runtimes. + assert result == 'dict' + assert received_types == ['dict'] + + +async def test_build_type_check_prefix_empty_lists(): + """Empty signatures/types produces just the typing import line.""" + env = MontyEnvironment() + prefix = env._build_type_check_prefix([], []) # pyright: ignore[reportPrivateUsage] + assert prefix == 'import asyncio\nfrom typing import Any, TypedDict, NotRequired, Literal' + + +# --- Sequential tool tests --- + + +def sync_add(*, x: int, y: int) -> int: + """Add two integers synchronously.""" + return x + y + + +async def test_sequential_tool_renders_as_def(): + """A sequential tool renders as `def` (not `async def`) in description and type-check prefix.""" + _, tools = await build_code_execution_toolset( + MontyEnvironment(), + Tool(sync_add, takes_ctx=False, sequential=True), + ) + tool = tools['run_code'] + + # Description shows `def`, not `async def` + description = tool.tool_def.description or '' + assert 'def sync_add(' in description + assert 'async def sync_add(' not in description + + # Type-check prefix also uses `def` + env = MontyEnvironment() + prefix = env._build_type_check_prefix(tool.signatures, tool.referenced_types) # pyright: ignore[reportPrivateUsage] + assert 'def sync_add(' in prefix + assert 'async def sync_add(' not in prefix + + +async def test_sequential_tool_execution(): + """Sequential tool is called as `result = my_tool(x=1)` (no `await`).""" + result = await run_code_with_tools( + 'sync_add(x=3, y=4)', + MontyEnvironment(), + Tool(sync_add, takes_ctx=False, sequential=True), + ) + assert result == 7 + + +async def test_mixed_sync_async_execution(): + """Both async (fire-then-await) and sync tools work in the same code block.""" + code = 'f = add(x=1, y=2)\nsum_result = sync_add(x=10, y=20)\nawaited = await f\n[awaited, sum_result]' + result = await run_code_with_tools( + code, + MontyEnvironment(), + (add, False), + Tool(sync_add, takes_ctx=False, sequential=True), + ) + assert result == [3, 30] + + +async def test_sequential_drain_behavior(): + """Firing an async tool then calling a sequential tool drains the async task first.""" + call_order: list[str] = [] + + def fire_tool(*, name: str) -> str: + """An async tool.""" + call_order.append(f'async:{name}') + return f'async:{name}' + + def seq_tool(*, name: str) -> str: + """A sequential tool.""" + call_order.append(f'seq:{name}') + return f'seq:{name}' + + code = 'f = fire_tool(name="first")\nresult = seq_tool(name="second")\nawaited = await f\n[awaited, result]' + result = await run_code_with_tools( + code, + MontyEnvironment(), + (fire_tool, False), + Tool(seq_tool, takes_ctx=False, sequential=True), + ) + assert result == ['async:first', 'seq:second'] + # The async tool must complete before the sequential tool starts + assert call_order == ['async:first', 'seq:second'] + + +async def test_await_on_sync_tool_is_type_error(): + """`await sync_tool()` raises ModelRetry (runtime error since type checking is disabled for REPL mode).""" + with pytest.raises(ModelRetry, match='Runtime error in generated code'): + await run_code_with_tools( + 'await sync_add(x=1, y=2)', + MontyEnvironment(), + Tool(sync_add, takes_ctx=False, sequential=True), + ) + + +async def test_global_sequential_mode(): + """Setting _parallel_execution_mode_ctx_var to 'sequential' makes all tools render as `def`.""" + token = _parallel_execution_mode_ctx_var.set('sequential') + try: + _, tools = await build_code_execution_toolset( + MontyEnvironment(), + (add, False), + ) + tool = tools['run_code'] + description = tool.tool_def.description or '' + assert 'def add(' in description + assert 'async def add(' not in description + finally: + _parallel_execution_mode_ctx_var.reset(token) + + +async def test_global_parallel_ordered_events_mode(): + """Setting _parallel_execution_mode_ctx_var to 'parallel_ordered_events' makes all tools render as `def`.""" + token = _parallel_execution_mode_ctx_var.set('parallel_ordered_events') + try: + _, tools = await build_code_execution_toolset( + MontyEnvironment(), + (add, False), + ) + tool = tools['run_code'] + description = tool.tool_def.description or '' + assert 'def add(' in description + assert 'async def add(' not in description + finally: + _parallel_execution_mode_ctx_var.reset(token) + + +async def test_description_no_all_functions_are_async(): + """The prompt no longer says 'All functions are async'.""" + _, tools = await build_code_execution_toolset(MontyEnvironment(), (add, False)) + description = tools['run_code'].tool_def.description or '' + assert 'All functions are async' not in description + + +async def test_pending_tasks_cancelled_on_runtime_error(): + """Async tasks fired before a runtime error are cancelled in the finally block.""" + code = 'f = add(x=1, y=2)\n1 / 0' + with pytest.raises(ModelRetry, match='Runtime error in generated code'): + await run_code_with_tools(code, MontyEnvironment(), (add, False)) + + +# --- Direct run_python tests --- + + +async def test_run_python_success(): + """run_python returns the last expression.""" + env = MontyEnvironment() + result = await env.run_python('"hello"') + assert result == 'hello' + + +async def test_run_python_syntax_error(): + """run_python raises CodeSyntaxError for invalid syntax.""" + from pydantic_harness.toolsets.code_execution._abstract import CodeSyntaxError + + env = MontyEnvironment() + with pytest.raises(CodeSyntaxError): + await env.run_python('def while') + + +async def test_run_python_type_annotation_not_enforced_at_runtime(): + """With type checking disabled for REPL mode, type annotations are not enforced.""" + env = MontyEnvironment() + # x: int = "hello" is accepted — Monty's runtime doesn't enforce annotations + result = await env.run_python('x: int = "hello"\nx') + assert result == 'hello' + + +async def test_run_python_runtime_error(): + """run_python raises CodeRuntimeError for runtime errors.""" + from pydantic_harness.toolsets.code_execution._abstract import CodeRuntimeError + + env = MontyEnvironment() + with pytest.raises(CodeRuntimeError): + await env.run_python('1 / 0') + + +async def test_run_python_with_functions_default_params(): + """run_python_with_functions works with default functions/referenced_types.""" + from unittest.mock import AsyncMock + + env = MontyEnvironment() + result = await env.run_python_with_functions( + '"hello"', + function_callback=AsyncMock(), + ) + assert result == 'hello' + + +# --- Reset and type_check tests --- + + +async def test_reset_clears_repl_state(): + """reset() discards REPL state so prior variables are no longer available.""" + env = MontyEnvironment() + await env.run_python('x = 42') + result = await env.run_python('x') + assert result == 42 + + env.reset() + + from pydantic_harness.toolsets.code_execution._abstract import CodeRuntimeError + + with pytest.raises(CodeRuntimeError, match='NameError'): + await env.run_python('x') + + +async def test_reset_allows_fresh_start(): + """After reset(), new code executes in a clean environment.""" + env = MontyEnvironment() + await env.run_python('x = "old"') + env.reset() + await env.run_python('x = "new"') + result = await env.run_python('x') + assert result == 'new' + + +async def test_type_check_catches_type_error(): + """type_check() raises CodeTypingError for type mismatches.""" + from pydantic_harness.toolsets.code_execution._abstract import CodeTypingError + + env = MontyEnvironment() + with pytest.raises(CodeTypingError): + env.type_check('x: int = "hello"') + + +async def test_type_check_catches_syntax_error(): + """type_check() raises CodeSyntaxError for invalid syntax.""" + from pydantic_harness.toolsets.code_execution._abstract import CodeSyntaxError + + env = MontyEnvironment() + with pytest.raises(CodeSyntaxError): + env.type_check('def while') + + +async def test_type_check_with_function_stubs(): + """type_check() validates calls against provided function signatures.""" + env = MontyEnvironment() + _, tools = await build_code_execution_toolset(env, (add, False)) + tool = tools['run_code'] + + # Valid call — should not raise + env.type_check('await add(x=1, y=2)', signatures=tool.signatures, referenced_types=tool.referenced_types) + + +async def test_type_check_valid_code_passes(): + """type_check() does not raise for valid code.""" + env = MontyEnvironment() + env.type_check('x: int = 42') # Should not raise + + +async def test_repl_state_persists_across_calls(): + """REPL state persists — variables survive between calls.""" + env = MontyEnvironment() + await env.run_python('count = 0') + await env.run_python('count = count + 1') + await env.run_python('count = count + 1') + result = await env.run_python('count') + assert result == 2 + + +# --- Print output tests --- + + +async def test_print_only_returns_string(): + """print() with no expression result returns the printed text.""" + env = MontyEnvironment() + result = await env.run_python('print("hello")') + assert result == 'hello' + + +async def test_print_multiple_lines(): + """Multiple print() calls are concatenated.""" + env = MontyEnvironment() + result = await env.run_python('print("line 1")\nprint("line 2")') + assert result == 'line 1\nline 2' + + +async def test_output_only_preserves_structure(): + """Expression result without prints is returned as-is (structured).""" + env = MontyEnvironment() + result = await env.run_python('[1, 2, 3]') + assert result == [1, 2, 3] + + +async def test_print_and_output_returns_dict(): + """print() combined with an expression result returns a dict with both.""" + env = MontyEnvironment() + result = await env.run_python('print("debug")\n[1, 2, 3]') + assert result == {'stdout': 'debug', 'result': [1, 2, 3]} + + +async def test_print_and_dict_output_returns_dict(): + """print() combined with a dict expression preserves the dict structure.""" + env = MontyEnvironment() + result = await env.run_python('print("info")\n{"key": "value"}') + assert result == {'stdout': 'info', 'result': {'key': 'value'}} + + +async def test_print_and_none_output_returns_string(): + """print() with None expression result returns just the printed text.""" + env = MontyEnvironment() + result = await env.run_python('print("hello")\nNone') + assert result == 'hello' + + +async def test_prints_included_in_runtime_error(): + """Print output before a runtime error is included in the error message.""" + from pydantic_harness.toolsets.code_execution._abstract import CodeRuntimeError + + env = MontyEnvironment() + with pytest.raises(CodeRuntimeError) as exc_info: + await env.run_python('print("debug info")\n1 / 0') + assert 'debug info' in exc_info.value.message + assert 'ZeroDivisionError' in exc_info.value.message + assert '[stdout before error]' in exc_info.value.message + + +async def test_prints_not_in_error_when_no_prints(): + """Error messages without prior prints don't have the stdout wrapper.""" + from pydantic_harness.toolsets.code_execution._abstract import CodeRuntimeError + + env = MontyEnvironment() + with pytest.raises(CodeRuntimeError) as exc_info: + await env.run_python('1 / 0') + assert 'stdout before error' not in exc_info.value.message + + +async def test_print_with_function_calls(): + """Print output is captured alongside external function calls.""" + env = MontyEnvironment() + result = await run_code_with_tools( + 'print("before call")\nr = await add(x=1, y=2)\nprint("after call")\nr', + env, + (add, False), + ) + assert result == {'stdout': 'before call\nafter call', 'result': 3} + + +async def test_print_after_function_call(): + """Prints after an external function call returns are still captured.""" + env = MontyEnvironment() + result = await run_code_with_tools( + 'r = await add(x=10, y=20)\nprint(r)', + env, + (add, False), + ) + assert result == '30' + + +async def test_prints_included_in_runtime_error_with_functions(): + """Print output before a runtime error with functions is included in the error message.""" + env = MontyEnvironment() + with pytest.raises(ModelRetry, match='debug') as exc_info: + await run_code_with_tools( + 'print("debug")\n1 / 0', + env, + (add, False), + ) + assert 'debug' in str(exc_info.value) + + +async def test_print_and_falsy_output_returns_dict(): + """print() combined with a falsy (but not None) result returns a dict.""" + env = MontyEnvironment() + result = await env.run_python('print("debug")\n0') + assert result == {'stdout': 'debug', 'result': 0} + + result = await env.run_python('print("debug")\nFalse') + assert result == {'stdout': 'debug', 'result': False} + + result = await env.run_python('print("debug")\n[]') + assert result == {'stdout': 'debug', 'result': []} + + +async def test_print_and_string_output_returns_dict(): + """print() combined with a string expression result returns a dict, not a flat string.""" + env = MontyEnvironment() + result = await env.run_python('print("debug")\n"hello"') + assert result == {'stdout': 'debug', 'result': 'hello'} + + +async def test_prints_included_in_timeout_error(): + """Print output before a timeout is included in the error message.""" + from pydantic_harness.toolsets.code_execution._abstract import CodeExecutionTimeout + + env = MontyEnvironment() + env.execution_timeout = 1.0 + with pytest.raises(CodeExecutionTimeout) as exc_info: + await env.run_python('print("started")\nwhile True: pass') + assert 'started' in exc_info.value.message + assert 'timed out' in exc_info.value.message diff --git a/tests/code_execution/test_sanitization.py b/tests/code_execution/test_sanitization.py new file mode 100644 index 0000000..689b422 --- /dev/null +++ b/tests/code_execution/test_sanitization.py @@ -0,0 +1,52 @@ +"""Tests for tool name sanitization edge cases.""" + +from __future__ import annotations + +import pytest + +from pydantic_harness.toolsets.code_execution import _sanitize_tool_name # pyright: ignore[reportPrivateUsage] + + +@pytest.mark.parametrize( + ('input_name', 'expected'), + [ + # Basic separators + ('search-records', 'search_records'), + ('get.user.data', 'get_user_data'), + ('hello world', 'hello_world'), + # camelCase / PascalCase -> snake_case + ('getUserData', 'get_user_data'), + ('XMLParser', 'xml_parser'), + # Leading digit + ('3d_model', '_3d_model'), + # Unicode stripped -> fallback + ('获取数据', 'tool'), + # All-separator -> underscore (separators become _, then lowercased) + ('---', '_'), + ('...', '_'), + # Python keywords + ('class', 'class_'), + ('return', 'return_'), + ('import', 'import_'), + # Keyword produced by camelCase conversion + ('returnValue', 'return_value'), + # Mixed separators + ('get-user.data name', 'get_user_data_name'), + # Already valid + ('valid_name', 'valid_name'), + # Single character + ('x', 'x'), + ], + ids=lambda v: repr(v), +) +def test_sanitize_tool_name(input_name: str, expected: str) -> None: + assert _sanitize_tool_name(input_name) == expected + + +def test_sanitize_collision_handling() -> None: + """Two distinct names that sanitize to the same result are handled by the caller (CodeExecutionToolset). + + _sanitize_tool_name itself is stateless -- it just normalizes a single name. + This test documents that identical outputs are possible. + """ + assert _sanitize_tool_name('get-data') == _sanitize_tool_name('get.data') == 'get_data' diff --git a/tests/code_execution/test_stdio_driver.py b/tests/code_execution/test_stdio_driver.py new file mode 100644 index 0000000..b733f59 --- /dev/null +++ b/tests/code_execution/test_stdio_driver.py @@ -0,0 +1,736 @@ +"""Tests for the stdio driver protocol. + +Tests _driver.py as a local subprocess (no Docker, no runtime class needed). +Communicates directly via NDJSON over stdin/stdout pipes. +""" + +from __future__ import annotations + +import asyncio +import json +import sys +from pathlib import Path + +import pytest + +from pydantic_harness.toolsets.code_execution._driver import ( + _build_proxy, # pyright: ignore[reportPrivateUsage] + _compile_code, # pyright: ignore[reportPrivateUsage] + _execute, # pyright: ignore[reportPrivateUsage] + _stdin_reader, # pyright: ignore[reportPrivateUsage] + _transform_last_expr, # pyright: ignore[reportPrivateUsage] +) + +DRIVER_PATH = ( + Path(__file__).parents[2] / 'pydantic_harness' / 'toolsets' / 'code_execution' / '_driver.py' +) + +pytestmark = pytest.mark.anyio + + +async def start_driver( + code: str, + functions: list[str] | None = None, + result_cache: dict[str, object] | None = None, +) -> asyncio.subprocess.Process: + """Start the driver subprocess and send the init message.""" + proc = await asyncio.create_subprocess_exec( + sys.executable, + '-u', + str(DRIVER_PATH), + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + init_msg: dict[str, object] = { + 'type': 'init', + 'code': code, + 'functions': functions or [], + } + if result_cache is not None: + init_msg['result_cache'] = result_cache + assert proc.stdin is not None + proc.stdin.write(json.dumps(init_msg).encode() + b'\n') + await proc.stdin.drain() + return proc + + +async def read_msg(proc: asyncio.subprocess.Process) -> dict[str, object]: + """Read a single NDJSON message from the driver's stdout.""" + assert proc.stdout is not None + line = await asyncio.wait_for(proc.stdout.readline(), timeout=10.0) + assert line, 'Driver produced no output (EOF)' + return json.loads(line) + + +async def send_msg(proc: asyncio.subprocess.Process, msg: dict[str, object]) -> None: + """Send a single NDJSON message to the driver's stdin.""" + assert proc.stdin is not None + proc.stdin.write(json.dumps(msg).encode() + b'\n') + await proc.stdin.drain() + + +async def test_simple_tool_call(): + """Driver sends call message, receives result, returns complete.""" + proc = await start_driver('await add(x=1, y=2)', functions=['add']) + + # Read call message + msg = await read_msg(proc) + assert msg['type'] == 'call' + assert msg['function'] == 'add' + assert msg['kwargs'] == {'x': 1, 'y': 2} + assert msg['id'] == 1 + + # Read calls_ready fence (batch boundary) + msg = await read_msg(proc) + assert msg['type'] == 'calls_ready' + + # Send result + await send_msg(proc, {'type': 'result', 'id': 1, 'result': 3}) + + # Read completion + msg = await read_msg(proc) + assert msg['type'] == 'complete' + assert msg['result'] == 3 + + proc.kill() + await proc.wait() + + +async def test_parallel_tool_calls(): + """Fire-then-await pattern sends multiple calls before awaiting.""" + code = 'f1 = add(x=1, y=2)\nf2 = add(x=3, y=4)\nr1 = await f1\nr2 = await f2\nr1 + r2' + proc = await start_driver(code, functions=['add']) + + # Read 2 call messages (both sent before any result) + msg1 = await read_msg(proc) + msg2 = await read_msg(proc) + assert msg1['type'] == 'call' + assert msg1['id'] == 1 + assert msg2['type'] == 'call' + assert msg2['id'] == 2 + + # Read calls_ready fence (batch boundary after last synchronous call) + msg = await read_msg(proc) + assert msg['type'] == 'calls_ready' + + # Send both results + await send_msg(proc, {'type': 'result', 'id': 1, 'result': 3}) + await send_msg(proc, {'type': 'result', 'id': 2, 'result': 7}) + + # Read completion + msg = await read_msg(proc) + assert msg['type'] == 'complete' + assert msg['result'] == 10 + + proc.kill() + await proc.wait() + + +async def test_result_cache_all_hits(): + """With full result cache, no call messages emitted.""" + proc = await start_driver( + 'r = await add(x=1, y=2)\nr', + functions=['add'], + result_cache={'1': 3}, + ) + + # Should go straight to complete (no call message) + msg = await read_msg(proc) + assert msg['type'] == 'complete' + assert msg['result'] == 3 + + proc.kill() + await proc.wait() + + +async def test_result_cache_partial_hit(): + """Partial cache: cached calls skip RPC, uncached go through.""" + code = 'f1 = add(x=1, y=2)\nf2 = add(x=3, y=4)\nr1 = await f1\nr2 = await f2\nr1 + r2' + proc = await start_driver(code, functions=['add'], result_cache={'1': 3}) + + # Only 1 call message (call 2, since call 1 is cached) + msg = await read_msg(proc) + assert msg['type'] == 'call' + assert msg['id'] == 2 + + # Read calls_ready fence + msg = await read_msg(proc) + assert msg['type'] == 'calls_ready' + + # Send result for call 2 + await send_msg(proc, {'type': 'result', 'id': 2, 'result': 7}) + + # Read completion + msg = await read_msg(proc) + assert msg['type'] == 'complete' + assert msg['result'] == 10 + + proc.kill() + await proc.wait() + + +async def test_syntax_error(): + """Syntax errors in code produce error message with type=syntax.""" + proc = await start_driver('def while invalid') + + msg = await read_msg(proc) + assert msg['type'] == 'error' + assert msg['error_type'] == 'syntax' + + proc.kill() + await proc.wait() + + +async def test_runtime_error(): + """Runtime exceptions produce error message with type=runtime.""" + proc = await start_driver('1 / 0') + + msg = await read_msg(proc) + assert msg['type'] == 'error' + assert msg['error_type'] == 'runtime' + assert 'ZeroDivisionError' in str(msg['error']) + + proc.kill() + await proc.wait() + + +async def test_empty_code(): + """Empty code returns None.""" + proc = await start_driver('') + + msg = await read_msg(proc) + assert msg['type'] == 'complete' + assert msg['result'] is None + + proc.kill() + await proc.wait() + + +async def test_no_function_calls(): + """Code that doesn't call functions executes locally.""" + proc = await start_driver('1 + 2') + + msg = await read_msg(proc) + assert msg['type'] == 'complete' + assert msg['result'] == 3 + + proc.kill() + await proc.wait() + + +async def test_last_expression_multiline_dict(): + """Multiline dict as last expression is correctly returned.""" + proc = await start_driver('{"a": 1, "b": 2}') + + msg = await read_msg(proc) + assert msg['type'] == 'complete' + assert msg['result'] == {'a': 1, 'b': 2} + + proc.kill() + await proc.wait() + + +async def test_print_goes_to_stderr(): + """print() in code goes to stderr, not stdout (protocol protected).""" + proc = await start_driver('print("hello")\n42') + + msg = await read_msg(proc) + assert msg['type'] == 'complete' + assert msg['result'] == 42 + + # Check stderr contains the print output + assert proc.stderr is not None + # Give the process a moment to finish writing stderr + proc.kill() + await proc.wait() + stderr_output = await proc.stderr.read() + assert b'hello' in stderr_output + + +async def test_tool_error_propagated(): + """Host error message surfaces as exception in code.""" + code = 'result = "no error"\ntry:\n await bad_tool()\nexcept Exception as e:\n result = str(e)\nresult' + proc = await start_driver(code, functions=['bad_tool']) + + # Read call message + msg = await read_msg(proc) + assert msg['type'] == 'call' + assert msg['function'] == 'bad_tool' + call_id = msg['id'] + + # Read calls_ready fence + msg = await read_msg(proc) + assert msg['type'] == 'calls_ready' + + # Send error response + await send_msg(proc, {'type': 'error', 'id': call_id, 'error': 'tool failed'}) + + # Code catches the exception and returns the error string + msg = await read_msg(proc) + assert msg['type'] == 'complete' + assert 'tool failed' in str(msg['result']) + + proc.kill() + await proc.wait() + + +async def test_undefined_function(): + """Calling undefined function raises runtime error.""" + proc = await start_driver('await nonexistent()') + + msg = await read_msg(proc) + assert msg['type'] == 'error' + assert msg['error_type'] == 'runtime' + assert 'nonexistent' in str(msg['error']) + + proc.kill() + await proc.wait() + + +async def test_sequential_call_ids(): + """Call IDs are sequential across different proxy functions.""" + code = ( + 'f1 = tool_a(x=1)\n' + 'f2 = tool_b(y=2)\n' + 'f3 = tool_a(x=3)\n' + 'r1 = await f1\n' + 'r2 = await f2\n' + 'r3 = await f3\n' + '[r1, r2, r3]' + ) + proc = await start_driver(code, functions=['tool_a', 'tool_b']) + + # Read 3 call messages — IDs should be 1, 2, 3 + msg1 = await read_msg(proc) + msg2 = await read_msg(proc) + msg3 = await read_msg(proc) + assert msg1['id'] == 1 + assert msg1['function'] == 'tool_a' + assert msg2['id'] == 2 + assert msg2['function'] == 'tool_b' + assert msg3['id'] == 3 + assert msg3['function'] == 'tool_a' + + # Read calls_ready fence (one per batch, after all 3 synchronous calls) + msg = await read_msg(proc) + assert msg['type'] == 'calls_ready' + + # Send all results + await send_msg(proc, {'type': 'result', 'id': 1, 'result': 'a1'}) + await send_msg(proc, {'type': 'result', 'id': 2, 'result': 'b1'}) + await send_msg(proc, {'type': 'result', 'id': 3, 'result': 'a2'}) + + msg = await read_msg(proc) + assert msg['type'] == 'complete' + assert msg['result'] == ['a1', 'b1', 'a2'] + + proc.kill() + await proc.wait() + + +async def test_positional_args(): + """Positional args are sent in the call message's args field.""" + proc = await start_driver('await add(1, 2)', functions=['add']) + + msg = await read_msg(proc) + assert msg['type'] == 'call' + assert msg['function'] == 'add' + assert msg['args'] == [1, 2] + assert msg['kwargs'] == {} + + # Read calls_ready fence + msg = await read_msg(proc) + assert msg['type'] == 'calls_ready' + + await send_msg(proc, {'type': 'result', 'id': 1, 'result': 3}) + + msg = await read_msg(proc) + assert msg['type'] == 'complete' + assert msg['result'] == 3 + + proc.kill() + await proc.wait() + + +async def test_mixed_positional_and_keyword_args(): + """Mixed positional and keyword args are split correctly.""" + proc = await start_driver('await add(1, y=2)', functions=['add']) + + msg = await read_msg(proc) + assert msg['type'] == 'call' + assert msg['function'] == 'add' + assert msg['args'] == [1] + assert msg['kwargs'] == {'y': 2} + + # Read calls_ready fence + msg = await read_msg(proc) + assert msg['type'] == 'calls_ready' + + await send_msg(proc, {'type': 'result', 'id': 1, 'result': 3}) + + msg = await read_msg(proc) + assert msg['type'] == 'complete' + assert msg['result'] == 3 + + proc.kill() + await proc.wait() + + +async def test_stderr_redirect_all_properties(): + """StderrRedirect stream properties are accessible from user code.""" + code = ( + 'import sys\n' + '{\n' + ' "isatty": sys.stdout.isatty(),\n' + ' "writable": sys.stdout.writable(),\n' + ' "readable": sys.stdout.readable(),\n' + ' "closed": sys.stdout.closed,\n' + ' "fileno_works": isinstance(sys.stdout.fileno(), int),\n' + ' "has_encoding": isinstance(sys.stdout.encoding, str),\n' + '}' + ) + proc = await start_driver(code) + msg = await read_msg(proc) + assert msg == { + 'type': 'complete', + 'result': { + 'isatty': False, + 'writable': True, + 'readable': False, + 'closed': False, + 'fileno_works': True, + 'has_encoding': True, + }, + } + proc.kill() + await proc.wait() + + +async def test_code_ending_with_assignment(): + """Code ending with assignment returns None (not the assigned value).""" + proc = await start_driver('x = 42') + msg = await read_msg(proc) + assert msg['type'] == 'complete' + assert msg['result'] is None + proc.kill() + await proc.wait() + + +async def test_runtime_syntax_error_in_eval(): + """SyntaxError raised at runtime (e.g. in eval) is reported as syntax error.""" + proc = await start_driver('eval("def while")') + msg = await read_msg(proc) + assert msg['type'] == 'error' + assert msg['error_type'] == 'syntax' + proc.kill() + await proc.wait() + + +async def _start_raw_driver() -> asyncio.subprocess.Process: + """Start driver without sending init — for testing init error paths.""" + return await asyncio.create_subprocess_exec( + sys.executable, + '-u', + str(DRIVER_PATH), + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + + +async def test_main_error_paths(): + """Driver handles missing/invalid/wrong-type init messages.""" + # No init (close stdin) + proc = await _start_raw_driver() + assert proc.stdin is not None + proc.stdin.close() + msg = await read_msg(proc) + assert 'No init message' in str(msg['error']) + proc.kill() + await proc.wait() + + # Invalid JSON + proc = await _start_raw_driver() + assert proc.stdin is not None + proc.stdin.write(b'not json\n') + await proc.stdin.drain() + msg = await read_msg(proc) + assert 'Invalid init message' in str(msg['error']) + proc.kill() + await proc.wait() + + # Wrong type + proc = await _start_raw_driver() + assert proc.stdin is not None + proc.stdin.write(json.dumps({'type': 'wrong'}).encode() + b'\n') + await proc.stdin.drain() + msg = await read_msg(proc) + assert 'Expected init message' in str(msg['error']) + proc.kill() + await proc.wait() + + +# ============================================================================= +# In-process unit tests — import driver functions directly so coverage tracks them +# ============================================================================= + + +def test_transform_last_expr_assignment(): + """Code ending in assignment has no return added.""" + result = _transform_last_expr('x = 42') + assert 'return' not in result + + +def test_compile_code_syntax_error(monkeypatch: pytest.MonkeyPatch): + """_compile_code with invalid code sends error and returns None.""" + messages: list[dict[str, object]] = [] + monkeypatch.setattr('pydantic_harness.toolsets.code_execution._driver._write_msg', messages.append) + result = _compile_code('def while', {}) + assert result is None + assert messages[0]['error_type'] == 'syntax' + + +async def test_build_proxy_cache_hit(): + """Proxy with a cache hit returns an already-resolved future.""" + loop = asyncio.get_running_loop() + call_counter: list[int] = [0] + pending: dict[int, asyncio.Future[object]] = {} + cache: dict[str, object] = {'1': 42} + handle: list[asyncio.Handle | None] = [None] + + proxy = _build_proxy('tool', loop, call_counter, pending, cache, handle) + future = proxy() + assert future.done() + assert future.result() == 42 + + +async def test_stdin_reader_eof(): + """EOF on reader exits cleanly.""" + reader = asyncio.StreamReader() + reader.feed_eof() + pending: dict[int, asyncio.Future[object]] = {} + await _stdin_reader(reader, pending) # should return without error + + +async def test_stdin_reader_malformed_json(capsys: pytest.CaptureFixture[str]): + """Malformed JSON produces a warning and continues to EOF.""" + reader = asyncio.StreamReader() + reader.feed_data(b'not json\n') + reader.feed_eof() + pending: dict[int, asyncio.Future[object]] = {} + await _stdin_reader(reader, pending) + assert 'malformed JSON' in capsys.readouterr().err + + +async def test_stdin_reader_error_message(): + """Error message sets exception on the pending future.""" + reader = asyncio.StreamReader() + reader.feed_data(json.dumps({'type': 'error', 'id': 1, 'error': 'fail'}).encode() + b'\n') + reader.feed_eof() + loop = asyncio.get_running_loop() + future: asyncio.Future[object] = loop.create_future() + pending: dict[int, asyncio.Future[object]] = {1: future} + await _stdin_reader(reader, pending) + with pytest.raises(RuntimeError, match='fail'): + future.result() + + +async def test_stdin_reader_result_message(): + """Result message sets the value on the pending future.""" + reader = asyncio.StreamReader() + reader.feed_data(json.dumps({'type': 'result', 'id': 1, 'result': 99}).encode() + b'\n') + reader.feed_eof() + loop = asyncio.get_running_loop() + future: asyncio.Future[object] = loop.create_future() + pending: dict[int, asyncio.Future[object]] = {1: future} + await _stdin_reader(reader, pending) + assert future.result() == 99 + + +async def test_execute_empty_code(monkeypatch: pytest.MonkeyPatch): + """Empty code sends complete with None.""" + messages: list[dict[str, object]] = [] + monkeypatch.setattr('pydantic_harness.toolsets.code_execution._driver._write_msg', messages.append) + reader = asyncio.StreamReader() + reader.feed_eof() + await _execute({'code': '', 'functions': []}, reader) + assert messages == [{'type': 'complete', 'result': None}] + + +async def test_execute_syntax_error(monkeypatch: pytest.MonkeyPatch): + """Syntax error in code sends error and returns.""" + messages: list[dict[str, object]] = [] + monkeypatch.setattr('pydantic_harness.toolsets.code_execution._driver._write_msg', messages.append) + reader = asyncio.StreamReader() + reader.feed_eof() + await _execute({'code': 'def while', 'functions': []}, reader) + assert messages[0]['error_type'] == 'syntax' + + +async def test_execute_runtime_error(monkeypatch: pytest.MonkeyPatch): + """Runtime error sends error with traceback.""" + messages: list[dict[str, object]] = [] + monkeypatch.setattr('pydantic_harness.toolsets.code_execution._driver._write_msg', messages.append) + reader = asyncio.StreamReader() + reader.feed_eof() + await _execute({'code': '1 / 0', 'functions': []}, reader) + assert messages[0]['error_type'] == 'runtime' + assert 'ZeroDivisionError' in str(messages[0]['error']) + + +async def test_execute_runtime_syntax_error(monkeypatch: pytest.MonkeyPatch): + """SyntaxError at runtime (e.g. eval) is reported as syntax.""" + messages: list[dict[str, object]] = [] + monkeypatch.setattr('pydantic_harness.toolsets.code_execution._driver._write_msg', messages.append) + reader = asyncio.StreamReader() + reader.feed_eof() + await _execute({'code': 'eval("def while")', 'functions': []}, reader) + assert messages[0]['error_type'] == 'syntax' + + +async def test_build_proxy_normal_call(monkeypatch: pytest.MonkeyPatch): + """Proxy without cache hit creates a pending future and writes call + calls_ready messages.""" + messages: list[dict[str, object]] = [] + monkeypatch.setattr('pydantic_harness.toolsets.code_execution._driver._write_msg', messages.append) + + loop = asyncio.get_running_loop() + call_counter: list[int] = [0] + pending: dict[int, asyncio.Future[object]] = {} + cache: dict[str, object] = {} + handle: list[asyncio.Handle | None] = [None] + + proxy = _build_proxy('add', loop, call_counter, pending, cache, handle) + future = proxy(x=1, y=2) + + # Future is pending (not resolved from cache) + assert not future.done() + assert 1 in pending + assert pending[1] is future + + # Call message was written + assert messages[0] == {'type': 'call', 'id': 1, 'function': 'add', 'args': [], 'kwargs': {'x': 1, 'y': 2}} + + # Let the event loop process the calls_ready callback + await asyncio.sleep(0) + assert messages[1] == {'type': 'calls_ready'} + + # Clean up: cancel the handle + assert handle[0] is not None + handle[0].cancel() + + +async def test_build_proxy_batched_calls(monkeypatch: pytest.MonkeyPatch): + """Two proxy calls without awaiting cancel the first calls_ready handle, emitting only one.""" + messages: list[dict[str, object]] = [] + monkeypatch.setattr('pydantic_harness.toolsets.code_execution._driver._write_msg', messages.append) + + loop = asyncio.get_running_loop() + call_counter: list[int] = [0] + pending: dict[int, asyncio.Future[object]] = {} + cache: dict[str, object] = {} + handle: list[asyncio.Handle | None] = [None] + + proxy = _build_proxy('add', loop, call_counter, pending, cache, handle) + proxy(x=1, y=2) + proxy(x=3, y=4) + + # Both call messages written, but only one calls_ready after the loop tick + await asyncio.sleep(0) + assert messages == [ + {'type': 'call', 'id': 1, 'function': 'add', 'args': [], 'kwargs': {'x': 1, 'y': 2}}, + {'type': 'call', 'id': 2, 'function': 'add', 'args': [], 'kwargs': {'x': 3, 'y': 4}}, + {'type': 'calls_ready'}, + ] + + assert handle[0] is not None + handle[0].cancel() + + +async def test_stdin_reader_result_no_pending_future(): + """Result for an unknown call ID is silently skipped.""" + reader = asyncio.StreamReader() + reader.feed_data(json.dumps({'type': 'result', 'id': 999, 'result': 42}).encode() + b'\n') + reader.feed_eof() + pending: dict[int, asyncio.Future[object]] = {} + await _stdin_reader(reader, pending) + # No crash — unknown ID is ignored + + +async def test_stdin_reader_unknown_msg_type(): + """Unknown message type is silently skipped.""" + reader = asyncio.StreamReader() + reader.feed_data(json.dumps({'type': 'heartbeat'}).encode() + b'\n') + reader.feed_eof() + pending: dict[int, asyncio.Future[object]] = {} + await _stdin_reader(reader, pending) + # No crash — unknown type is ignored + + +async def test_stdin_reader_error_no_pending_future(): + """Error for an unknown call ID is silently skipped.""" + reader = asyncio.StreamReader() + reader.feed_data(json.dumps({'type': 'error', 'id': 999, 'error': 'fail'}).encode() + b'\n') + reader.feed_eof() + pending: dict[int, asyncio.Future[object]] = {} + await _stdin_reader(reader, pending) + # No crash — unknown ID is ignored + + +async def test_os_write_to_fd1_does_not_corrupt_protocol(): + """os.write(1, ...) goes to stderr after fd-level redirect, protocol is intact.""" + fake_msg = json.dumps({'type': 'complete', 'result': 'SPOOFED'}) + code = f'import os\nos.write(1, {fake_msg.encode()!r} + b"\\n")\n"real result"' + proc = await start_driver(code) + + msg = await read_msg(proc) + assert msg['type'] == 'complete' + assert msg['result'] == 'real result' + + proc.kill() + await proc.wait() + + +async def test_stdin_reader_result_for_done_future(): + """Result for an already-resolved future does not overwrite it.""" + reader = asyncio.StreamReader() + reader.feed_data(json.dumps({'type': 'result', 'id': 1, 'result': 99}).encode() + b'\n') + reader.feed_eof() + loop = asyncio.get_running_loop() + future: asyncio.Future[object] = loop.create_future() + future.set_result('already done') + pending: dict[int, asyncio.Future[object]] = {1: future} + await _stdin_reader(reader, pending) + assert future.result() == 'already done' + + +async def test_execute_with_function_call(monkeypatch: pytest.MonkeyPatch): + """_execute with functions builds proxies, calls them, and completes successfully.""" + messages: list[dict[str, object]] = [] + monkeypatch.setattr('pydantic_harness.toolsets.code_execution._driver._write_msg', messages.append) + + reader = asyncio.StreamReader() + + async def provide_result(): + # Wait for the calls_ready message + while not any(m.get('type') == 'calls_ready' for m in messages): + await asyncio.sleep(0.01) + # Feed result for call id 1 + reader.feed_data(json.dumps({'type': 'result', 'id': 1, 'result': 42}).encode() + b'\n') + reader.feed_eof() + + provider = asyncio.create_task(provide_result()) + + await _execute( + {'code': 'await add(x=1, y=2)', 'functions': ['add']}, + reader, + ) + await provider + + call_msgs = [m for m in messages if m.get('type') == 'call'] + assert len(call_msgs) == 1 + assert call_msgs[0]['function'] == 'add' + + complete_msgs = [m for m in messages if m.get('type') == 'complete'] + assert len(complete_msgs) == 1 + assert complete_msgs[0]['result'] == 42 diff --git a/tests/code_execution/test_transport.py b/tests/code_execution/test_transport.py new file mode 100644 index 0000000..e0c485f --- /dev/null +++ b/tests/code_execution/test_transport.py @@ -0,0 +1,463 @@ +"""Tests for the driver transport protocol handler.""" + +from __future__ import annotations + +import asyncio +import json +from typing import Any +from unittest.mock import AsyncMock + +import pytest + +from pydantic_harness.environments._driver import ( + DriverBasedEnvironment, + DriverTransport, + ExecutionProcessTransport, + _StdoutSignal, # pyright: ignore[reportPrivateUsage] +) +from pydantic_harness.toolsets.code_execution._abstract import ( + CodeExecutionTimeout, + CodeRuntimeError, + CodeSyntaxError, +) + +pytestmark = pytest.mark.anyio + + +# --- Helpers --- + + +class _MockTransport(DriverTransport): + """Mock transport implementing the public DriverTransport interface.""" + + def __init__(self, *, stderr: bytes = b''): + self._stderr = stderr + + async def read_line(self) -> bytes: + return b'' # pragma: no cover + + async def write_line(self, data: bytes) -> None: + pass + + async def read_stderr(self) -> bytes: + return self._stderr + + async def kill(self) -> None: + pass + + +async def _handle(data: bytes, *, stderr: bytes = b'') -> Any: + """Create a resolved task and feed it to _handle_stdout.""" + + async def _return() -> bytes: + return data + + task: asyncio.Task[bytes] = asyncio.ensure_future(_return()) + await task + return await DriverBasedEnvironment._handle_stdout( # pyright: ignore[reportPrivateUsage] + task, _MockTransport(stderr=stderr), AsyncMock(), {}, {} + ) + + +class _ScriptedTransport(DriverTransport): + """Transport that feeds pre-programmed NDJSON messages. + + Messages are consumed in order from `lines`. When a message is a `call`, + the transport waits for a result to be written back via `write_line` + before yielding the next message. + """ + + def __init__(self, lines: list[dict[str, Any]], *, stderr: bytes = b''): + self._lines = [json.dumps(m).encode() + b'\n' for m in lines] + self._index = 0 + self._stderr = stderr + self._written: list[bytes] = [] + self._killed = False + self._result_event: asyncio.Event = asyncio.Event() + # Pre-set so non-call messages proceed immediately + self._result_event.set() + + async def read_line(self) -> bytes: + # Wait for any pending result write before yielding the next line + await self._result_event.wait() + if self._index >= len(self._lines): + return b'' # pragma: no cover + line = self._lines[self._index] + self._index += 1 + # If this was a call message, block until the host writes a result + msg = json.loads(line) + if msg.get('type') == 'call': + self._result_event.clear() + return line + + async def write_line(self, data: bytes) -> None: + self._written.append(data) + self._result_event.set() + + async def read_stderr(self) -> bytes: + return self._stderr # pragma: no cover + + async def kill(self) -> None: + self._killed = True + + +class _ScriptedEnvironment(DriverBasedEnvironment): + """Environment that uses a _ScriptedTransport.""" + + def __init__(self, lines: list[dict[str, Any]], **kwargs: Any): + self._transport = _ScriptedTransport(lines, **kwargs) + + @property + def capabilities(self) -> frozenset[Any]: + return frozenset({'run_python', 'run_python_with_functions'}) # pragma: no cover + + async def _copy_driver(self) -> None: + pass + + async def _start_driver(self, init_msg: dict[str, Any]) -> DriverTransport: + return self._transport + + +class _ErrorTransport(DriverTransport): + """Transport that raises TypeError on read_line.""" + + async def read_line(self) -> bytes: + raise TypeError('bad read') + + async def write_line(self, data: bytes) -> None: + pass # pragma: no cover + + async def read_stderr(self) -> bytes: + return b'' # pragma: no cover + + async def kill(self) -> None: + pass + + +class _ErrorEnvironment(DriverBasedEnvironment): + """Environment that uses _ErrorTransport.""" + + @property + def capabilities(self) -> frozenset[Any]: + return frozenset({'run_python'}) # pragma: no cover + + async def _copy_driver(self) -> None: + pass + + async def _start_driver(self, init_msg: dict[str, Any]) -> DriverTransport: + return _ErrorTransport() + + +class _BlockingTransport(DriverTransport): + """Transport whose read_line blocks forever (for timeout tests).""" + + async def read_line(self) -> bytes: + await asyncio.sleep(999) + return b'' # pragma: no cover + + async def write_line(self, data: bytes) -> None: + pass # pragma: no cover + + async def read_stderr(self) -> bytes: + return b'' # pragma: no cover + + async def kill(self) -> None: + pass + + +class _TimeoutEnvironment(DriverBasedEnvironment): + """Environment with a very short execution timeout.""" + + execution_timeout: float | None = 0.01 + + @property + def capabilities(self) -> frozenset[Any]: + return frozenset({'run_python', 'run_python_with_functions'}) # pragma: no cover + + async def _copy_driver(self) -> None: + pass + + async def _start_driver(self, init_msg: dict[str, Any]) -> DriverTransport: + return _BlockingTransport() + + +# --- _handle_stdout unit tests --- + + +async def test_handle_stdout_eof_with_stderr(): + """EOF on stdout reads stderr for error message.""" + with pytest.raises(CodeRuntimeError, match='segfault'): + await _handle(b'', stderr=b'segfault\n') + + +async def test_handle_stdout_eof_no_stderr(): + """EOF with no stderr gives default message.""" + with pytest.raises(CodeRuntimeError, match='exited unexpectedly'): + await _handle(b'') + + +async def test_handle_stdout_malformed_json(): + """Non-JSON output raises CodeRuntimeError.""" + with pytest.raises(CodeRuntimeError, match='Malformed'): + await _handle(b'not json\n') + + +async def test_handle_stdout_unknown_msg_type(): + """Unknown message type returns CONTINUE.""" + result = await _handle(json.dumps({'type': 'unknown'}).encode() + b'\n') + assert result == _StdoutSignal.CONTINUE + + +# --- Full execution loop tests --- + + +async def test_complete_message_returns_result(): + """A 'complete' message returns the final result.""" + env = _ScriptedEnvironment([{'type': 'complete', 'result': 'hello'}]) + result = await env.run_python_with_functions( + 'x = 1', + function_callback=AsyncMock(), + functions={}, + referenced_types=[], + ) + assert result == 'hello' + assert env._transport._killed # pyright: ignore[reportPrivateUsage] + + +async def test_error_message_raises_runtime_error(): + """An 'error' message raises CodeRuntimeError.""" + env = _ScriptedEnvironment([{'type': 'error', 'error': 'boom'}]) + with pytest.raises(CodeRuntimeError, match='boom'): + await env.run_python_with_functions( + 'x = 1', + function_callback=AsyncMock(), + functions={}, + referenced_types=[], + ) + + +async def test_syntax_error_message_raises_syntax_error(): + """An 'error' message with error_type='syntax' raises CodeSyntaxError.""" + env = _ScriptedEnvironment([{'type': 'error', 'error_type': 'syntax', 'error': 'bad syntax'}]) + with pytest.raises(CodeSyntaxError, match='bad syntax'): + await env.run_python_with_functions( + 'x = 1', + function_callback=AsyncMock(), + functions={}, + referenced_types=[], + ) + + +async def test_function_call_and_result(): + """A 'call' message dispatches a function callback, then sends the result back.""" + + async def callback(call: Any) -> str: + assert call.function_name == 'my_func' + assert call.kwargs == {'x': 1} + return 'callback_result' + + env = _ScriptedEnvironment( + [ + {'type': 'call', 'id': 1, 'function': 'my_func', 'kwargs': {'x': 1}}, + {'type': 'complete', 'result': 'done'}, + ] + ) + result = await env.run_python_with_functions( + 'my_func(x=1)', + function_callback=callback, + functions={}, + referenced_types=[], + ) + assert result == 'done' + # Verify a result was written back + assert len(env._transport._written) > 0 # pyright: ignore[reportPrivateUsage] + written_msg = json.loads(env._transport._written[0]) # pyright: ignore[reportPrivateUsage] + assert written_msg == {'type': 'result', 'id': 1, 'result': 'callback_result'} + + +async def test_calls_ready_continues(): + """A 'calls_ready' message continues the loop.""" + env = _ScriptedEnvironment( + [ + {'type': 'calls_ready'}, + {'type': 'complete', 'result': 42}, + ] + ) + result = await env.run_python_with_functions( + 'x = 1', + function_callback=AsyncMock(), + functions={}, + referenced_types=[], + ) + assert result == 42 + + +async def test_execution_timeout(): + """Timeout raises CodeExecutionTimeout.""" + env = _TimeoutEnvironment() + with pytest.raises(CodeExecutionTimeout, match='timed out'): + await env.run_python_with_functions( + 'x = 1', + function_callback=AsyncMock(), + functions={}, + referenced_types=[], + ) + + +async def test_run_catches_non_standard_exception(): + """Non-standard exceptions from transport are wrapped in CodeRuntimeError.""" + env = _ErrorEnvironment() + with pytest.raises(CodeRuntimeError, match='Driver communication error.*bad read'): + await env.run_python_with_functions( + 'x = 1', + function_callback=AsyncMock(), + functions={}, + referenced_types=[], + ) + + +async def test_error_reraised_from_run_python_with_functions(): + """CodeRuntimeError and CodeSyntaxError are re-raised directly (not wrapped).""" + env = _ScriptedEnvironment([{'type': 'error', 'error_type': 'syntax', 'error': 'parse fail'}]) + with pytest.raises(CodeSyntaxError, match='parse fail'): + await env.run_python_with_functions( + 'x = 1', + function_callback=AsyncMock(), + functions={}, + referenced_types=[], + ) + + +async def test_tool_error_unwrapped(): + """When a tool callback raises, the cause is unwrapped from _ToolError.""" + + async def bad_callback(call: Any) -> str: + raise ValueError('tool exploded') + + env = _ScriptedEnvironment( + [ + {'type': 'call', 'id': 1, 'function': 'exploder', 'kwargs': {}}, + {'type': 'complete', 'result': 'never reached'}, + ] + ) + with pytest.raises(ValueError, match='tool exploded'): + await env.run_python_with_functions( + 'exploder()', + function_callback=bad_callback, + functions={}, + referenced_types=[], + ) + + +# --- ExecutionProcessTransport tests --- + + +class _MockProcess: + """Minimal mock of ExecutionProcess for testing ExecutionProcessTransport.""" + + def __init__(self, chunks: list[bytes]): + self._chunks = list(chunks) + self._sent: list[bytes] = [] + self._killed = False + + async def recv(self, timeout: float | None = None) -> bytes: + if not self._chunks: + return b'' + return self._chunks.pop(0) + + async def recv_stderr(self, timeout: float | None = None) -> bytes: + raise OSError('stderr error') + + async def send(self, data: bytes) -> None: + self._sent.append(data) + + async def kill(self) -> None: + self._killed = True + + +async def test_execution_process_transport_read_line(): + """read_line accumulates chunks until newline.""" + proc = _MockProcess([b'hel', b'lo\nwo', b'rld\n']) + transport = ExecutionProcessTransport(proc) # pyright: ignore[reportArgumentType] + line = await transport.read_line() + assert line == b'hello\n' + # Second read should use the buffered data + line2 = await transport.read_line() + assert line2 == b'world\n' + + +async def test_execution_process_transport_read_line_eof(): + """read_line returns remaining buffer on EOF (empty recv).""" + proc = _MockProcess([b'partial']) + transport = ExecutionProcessTransport(proc) # pyright: ignore[reportArgumentType] + line = await transport.read_line() + assert line == b'partial' + + +async def test_execution_process_transport_write_line(): + """write_line delegates to process.send.""" + proc = _MockProcess([]) + transport = ExecutionProcessTransport(proc) # pyright: ignore[reportArgumentType] + await transport.write_line(b'data\n') + assert proc._sent == [b'data\n'] # pyright: ignore[reportPrivateUsage] + + +async def test_execution_process_transport_read_stderr(): + """read_stderr returns empty bytes when the process raises.""" + proc = _MockProcess([]) + transport = ExecutionProcessTransport(proc) # pyright: ignore[reportArgumentType] + result = await transport.read_stderr() + assert result == b'' + + +async def test_execution_process_transport_kill(): + """kill delegates to process.kill.""" + proc = _MockProcess([]) + transport = ExecutionProcessTransport(proc) # pyright: ignore[reportArgumentType] + await transport.kill() + assert proc._killed # pyright: ignore[reportPrivateUsage] + + +# --- Default _start_driver test --- + + +async def test_default_start_driver(): + """The default _start_driver creates a process and sends the init message.""" + sent_data: list[bytes] = [] + + class _FakeProcess: + async def __aenter__(self) -> _FakeProcess: + return self + + async def __aexit__(self, *args: Any) -> None: + pass # pragma: no cover + + async def recv(self, timeout: float | None = None) -> bytes: + return b'' # pragma: no cover + + async def send(self, data: bytes) -> None: + sent_data.append(data) + + async def recv_stderr(self, timeout: float | None = None) -> bytes: + return b'' # pragma: no cover + + async def kill(self) -> None: + pass # pragma: no cover + + class _FakeDriverEnvironment(DriverBasedEnvironment): + @property + def capabilities(self) -> frozenset[Any]: + return frozenset({'run_python', 'run_python_with_functions'}) # pragma: no cover + + async def _copy_driver(self) -> None: + pass + + async def create_process(self, command: str, **kwargs: Any) -> Any: + return _FakeProcess() + + env = _FakeDriverEnvironment() + init_msg: dict[str, Any] = {'type': 'init', 'code': 'x = 1', 'functions': []} + transport = await env._start_driver(init_msg) # pyright: ignore[reportPrivateUsage] + assert isinstance(transport, ExecutionProcessTransport) + assert len(sent_data) == 1 + assert json.loads(sent_data[0]) == init_msg diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..3b6528b --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from datetime import datetime +from typing import TYPE_CHECKING, Any + +import pytest + +if TYPE_CHECKING: + from typing import TypeVar + + T = TypeVar('T') + + def IsInstance(arg: type[T]) -> T: ... + def IsDatetime(*args: Any, **kwargs: Any) -> datetime: ... + def IsFloat(*args: Any, **kwargs: Any) -> float: ... + def IsInt(*args: Any, **kwargs: Any) -> int: ... + def IsNow(*args: Any, **kwargs: Any) -> datetime: ... + def IsStr(*args: Any, **kwargs: Any) -> str: ... + def IsBytes(*args: Any, **kwargs: Any) -> bytes: ... + def IsList(*args: T, **kwargs: Any) -> list[T]: ... +else: + from dirty_equals import IsBytes, IsDatetime, IsFloat, IsInstance, IsInt, IsList, IsNow as _IsNow, IsStr + + def IsNow(*args: Any, **kwargs: Any): + if 'delta' not in kwargs: + kwargs['delta'] = 10 + return _IsNow(*args, **kwargs) + +__all__ = ( + 'IsDatetime', + 'IsFloat', + 'IsNow', + 'IsStr', + 'IsBytes', + 'IsInt', + 'IsInstance', + 'IsList', +) + +pytest_plugins = [] diff --git a/tests/test_environments.py b/tests/test_environments.py new file mode 100644 index 0000000..c7cfc26 --- /dev/null +++ b/tests/test_environments.py @@ -0,0 +1,2205 @@ +"""Tests for pydantic_harness.environments -- ExecutionEnvironment, ExecutionEnvironmentToolset, LocalEnvironment, and MemoryEnvironment.""" + +from __future__ import annotations + +import os +from pathlib import Path +from typing import Any +from unittest.mock import MagicMock + +import pytest +from inline_snapshot import snapshot + +from pydantic_ai import ToolCallPart +from pydantic_ai._run_context import RunContext +from pydantic_ai._tool_manager import ToolManager +from pydantic_harness.environments import ExecutionEnvironmentToolset, ExecutionResult, FileInfo +from pydantic_harness.environments._base import ( + apply_edit, + build_glob_cmd, + build_grep_cmd, + build_read_file_cmd, + filter_grep_count_output, + format_lines, + glob_match, + parse_glob_output, + shell_escape, +) +from pydantic_harness.environments.local import LocalEnvironment +from pydantic_harness.environments.memory import MemoryEnvironment +from pydantic_ai.exceptions import UnexpectedModelBehavior +from pydantic_ai.models.test import TestModel +from pydantic_ai.usage import RunUsage + +pytestmark = pytest.mark.anyio + + +def build_run_context(deps: Any = None, run_step: int = 0) -> RunContext[Any]: + return RunContext( + deps=deps, + model=TestModel(), + usage=RunUsage(), + prompt=None, + messages=[], + run_step=run_step, + ) + + +# --- Data types --- + + +def test_execute_result(): + result = ExecutionResult(output='hello\n', exit_code=0) + assert result.output == 'hello\n' + assert result.exit_code == 0 + assert result.truncated is False + + +def test_execute_result_truncated(): + result = ExecutionResult(output='data', exit_code=1, truncated=True) + assert result.truncated is True + + +def test_file_info(): + info = FileInfo(name='test.py', path='src/test.py', is_dir=False, size=42) + assert info.name == 'test.py' + assert info.is_dir is False + assert info.size == 42 + + +def test_file_info_directory(): + info = FileInfo(name='src', path='src', is_dir=True) + assert info.is_dir is True + assert info.size is None + + +# --- LocalEnvironment: execute --- + + +async def test_local_execute_basic(tmp_path: Path): + async with LocalEnvironment(tmp_path) as env: + result = await env.shell('echo hello') + assert result.exit_code == 0 + assert 'hello' in result.output + + +async def test_local_execute_exit_code(tmp_path: Path): + async with LocalEnvironment(tmp_path) as env: + result = await env.shell('exit 42') + assert result.exit_code == 42 + + +async def test_local_execute_timeout(tmp_path: Path): + async with LocalEnvironment(tmp_path) as env: + result = await env.shell('sleep 10', timeout=0.5) + assert result.exit_code == -1 + assert 'timed out' in result.output.lower() + + +async def test_local_execute_stderr(tmp_path: Path): + async with LocalEnvironment(tmp_path) as env: + result = await env.shell('echo error >&2') + assert 'error' in result.output + + +# --- LocalEnvironment: environment variables --- + + +async def test_local_env_vars_baseline(tmp_path: Path): + async with LocalEnvironment(tmp_path, env_vars={'MY_VAR': 'baseline'}) as env: + result = await env.shell('echo $MY_VAR') + assert 'baseline' in result.output + + +async def test_local_env_vars_per_call(tmp_path: Path): + async with LocalEnvironment(tmp_path) as env: + result = await env.shell('echo $CALL_VAR', env={'CALL_VAR': 'per_call'}) + assert 'per_call' in result.output + + +async def test_local_env_vars_merged(tmp_path: Path): + async with LocalEnvironment(tmp_path, env_vars={'BASE': 'one'}) as env: + result = await env.shell('echo $BASE $EXTRA', env={'EXTRA': 'two'}) + assert 'one' in result.output + assert 'two' in result.output + + +async def test_local_env_vars_per_call_overrides_baseline(tmp_path: Path): + async with LocalEnvironment(tmp_path, env_vars={'VAR': 'old'}) as env: + result = await env.shell('echo $VAR', env={'VAR': 'new'}) + assert 'new' in result.output + assert 'old' not in result.output + + +async def test_local_inherit_env_true(tmp_path: Path): + os.environ['_TEST_INHERIT_CHECK'] = 'inherited' + try: + async with LocalEnvironment(tmp_path, inherit_env=True) as env: + result = await env.shell('echo $_TEST_INHERIT_CHECK') + assert 'inherited' in result.output + finally: + del os.environ['_TEST_INHERIT_CHECK'] + + +async def test_local_inherit_env_false(tmp_path: Path): + os.environ['_TEST_INHERIT_CHECK'] = 'should_not_see' + try: + async with LocalEnvironment(tmp_path, inherit_env=False) as env: + result = await env.shell('echo x${_TEST_INHERIT_CHECK}x') + assert result.output.strip() == 'xx' + finally: + del os.environ['_TEST_INHERIT_CHECK'] + + +async def test_local_inherit_env_false_with_explicit_vars(tmp_path: Path): + async with LocalEnvironment(tmp_path, env_vars={'ONLY_THIS': 'yes'}, inherit_env=False) as env: + result = await env.shell('/bin/echo $ONLY_THIS') + assert 'yes' in result.output + + +# --- LocalEnvironment: file operations --- + + +async def test_local_write_and_read(tmp_path: Path): + async with LocalEnvironment(tmp_path) as env: + await env.write_file('test.txt', 'line one\nline two\n') + content = await env.read_file('test.txt') + assert isinstance(content, str) + assert 'line one' in content + assert 'line two' in content + + +async def test_local_read_line_numbers(tmp_path: Path): + async with LocalEnvironment(tmp_path) as env: + await env.write_file('numbered.txt', 'alpha\nbeta\ngamma\n') + content = await env.read_file('numbered.txt') + assert content == snapshot("""\ + 1\talpha + 2\tbeta + 3\tgamma +""") + + +async def test_local_read_with_offset_limit(tmp_path: Path): + async with LocalEnvironment(tmp_path) as env: + lines = '\n'.join(f'line {i}' for i in range(20)) + await env.write_file('long.txt', lines) + + content = await env.read_file('long.txt', offset=5, limit=3) + assert content == snapshot("""\ + 6\tline 5 + 7\tline 6 + 8\tline 7 +... (12 more lines. Use offset=8 to continue reading.) +""") + + +async def test_local_read_continuation_hint(tmp_path: Path): + async with LocalEnvironment(tmp_path) as env: + lines = '\n'.join(f'line {i}' for i in range(20)) + await env.write_file('long.txt', lines) + + content = await env.read_file('long.txt', offset=0, limit=5) + assert content == snapshot("""\ + 1\tline 0 + 2\tline 1 + 3\tline 2 + 4\tline 3 + 5\tline 4 +... (15 more lines. Use offset=5 to continue reading.) +""") + + +async def test_local_read_offset_out_of_bounds(tmp_path: Path): + async with LocalEnvironment(tmp_path) as env: + await env.write_file('short.txt', 'one\ntwo\n') + with pytest.raises(ValueError, match='Offset 100 exceeds file length'): + await env.read_file('short.txt', offset=100) + + +async def test_local_read_directory_error(tmp_path: Path): + async with LocalEnvironment(tmp_path) as env: + (tmp_path / 'subdir').mkdir() + with pytest.raises(FileNotFoundError, match='is a directory'): + await env.read_file('subdir') + + +async def test_local_read_nonexistent(tmp_path: Path): + async with LocalEnvironment(tmp_path) as env: + with pytest.raises(FileNotFoundError): + await env.read_file('nonexistent.txt') + + +async def test_local_write_creates_parent_dirs(tmp_path: Path): + async with LocalEnvironment(tmp_path) as env: + await env.write_file('deep/nested/dir/file.txt', 'content') + content = await env.read_file('deep/nested/dir/file.txt') + assert isinstance(content, str) + assert 'content' in content + + +async def test_local_write_binary(tmp_path: Path): + async with LocalEnvironment(tmp_path) as env: + await env.write_file('binary.bin', b'\x00\x01\x02\x03') + assert (tmp_path / 'binary.bin').read_bytes() == b'\x00\x01\x02\x03' + + +async def test_local_read_file_bytes(tmp_path: Path): + # Create a minimal PNG (1x1 transparent pixel) + png_data = ( + b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01' + b'\x00\x00\x00\x01\x08\x06\x00\x00\x00\x1f\x15\xc4\x89' + b'\x00\x00\x00\nIDATx\x9cc\x00\x01\x00\x00\x05\x00\x01' + b'\r\n\xb4\x00\x00\x00\x00IEND\xaeB`\x82' + ) + async with LocalEnvironment(tmp_path) as env: + await env.write_file('image.png', png_data) + result = await env.read_file('image.png') + assert isinstance(result, bytes) + assert result == png_data + + +# --- LocalEnvironment: edit_file --- + + +async def test_local_edit_single_replacement(tmp_path: Path): + async with LocalEnvironment(tmp_path) as env: + await env.write_file('edit.txt', 'foo bar baz') + count = await env.replace_str('edit.txt', 'bar', 'BAR') + assert count == 1 + content = (tmp_path / 'edit.txt').read_text() + assert content == 'foo BAR baz' + + +async def test_local_edit_replace_all(tmp_path: Path): + async with LocalEnvironment(tmp_path) as env: + await env.write_file('edit.txt', 'aaa bbb aaa') + count = await env.replace_str('edit.txt', 'aaa', 'xxx', replace_all=True) + assert count == 2 + content = (tmp_path / 'edit.txt').read_text() + assert content == 'xxx bbb xxx' + + +async def test_local_edit_not_found(tmp_path: Path): + async with LocalEnvironment(tmp_path) as env: + await env.write_file('edit.txt', 'hello world') + with pytest.raises(ValueError, match='not found'): + await env.replace_str('edit.txt', 'missing', 'replacement') + + +async def test_local_edit_ambiguous_without_replace_all(tmp_path: Path): + async with LocalEnvironment(tmp_path) as env: + await env.write_file('edit.txt', 'dup dup dup') + with pytest.raises(ValueError, match='3 times'): + await env.replace_str('edit.txt', 'dup', 'unique') + + +async def test_local_edit_nonexistent_file(tmp_path: Path): + async with LocalEnvironment(tmp_path) as env: + with pytest.raises(FileNotFoundError): + await env.replace_str('missing.txt', 'old', 'new') + + +async def test_local_edit_multiline(tmp_path: Path): + async with LocalEnvironment(tmp_path) as env: + await env.write_file('code.py', 'def foo():\n return "old"\n\nprint("test")\n') + count = await env.replace_str('code.py', 'def foo():\n return "old"', 'def foo():\n return "new"') + assert count == 1 + content = (tmp_path / 'code.py').read_text() + assert 'return "new"' in content + assert 'return "old"' not in content + assert 'print("test")' in content + + +# --- LocalEnvironment: ls --- + + +async def test_local_ls(tmp_path: Path): + async with LocalEnvironment(tmp_path) as env: + await env.write_file('a.txt', 'a') + await env.write_file('b.txt', 'b') + (tmp_path / 'subdir').mkdir() + + entries = await env.ls('.') + names = {e.name for e in entries} + assert 'a.txt' in names + assert 'b.txt' in names + assert 'subdir' in names + + dirs = [e for e in entries if e.is_dir] + files = [e for e in entries if not e.is_dir] + assert any(d.name == 'subdir' for d in dirs) + assert all(f.size is not None and f.size > 0 for f in files) + + +async def test_local_ls_not_a_directory(tmp_path: Path): + async with LocalEnvironment(tmp_path) as env: + await env.write_file('file.txt', 'content') + with pytest.raises(NotADirectoryError): + await env.ls('file.txt') + + +# --- LocalEnvironment: glob --- + + +async def test_local_glob(tmp_path: Path): + async with LocalEnvironment(tmp_path) as env: + await env.write_file('src/main.py', '# main') + await env.write_file('src/utils.py', '# utils') + await env.write_file('src/data.json', '{}') + + matches = await env.glob('**/*.py') + assert len(matches) == 2 + assert any('main.py' in m for m in matches) + assert any('utils.py' in m for m in matches) + assert not any('data.json' in m for m in matches) + + +async def test_local_glob_no_matches(tmp_path: Path): + async with LocalEnvironment(tmp_path) as env: + matches = await env.glob('**/*.rs') + assert matches == [] + + +# --- LocalEnvironment: grep --- + + +async def test_local_grep(tmp_path: Path): + async with LocalEnvironment(tmp_path) as env: + await env.write_file('a.py', 'def hello():\n pass\n') + await env.write_file('b.py', 'x = 1\n') + + result = await env.grep('hello') + assert 'a.py' in result + assert 'hello' in result + assert 'b.py' not in result + + +async def test_local_grep_with_glob_pattern(tmp_path: Path): + async with LocalEnvironment(tmp_path) as env: + await env.write_file('code.py', 'target = 1\n') + await env.write_file('code.js', 'target = 2\n') + + result = await env.grep('target', glob_pattern='*.py') + assert 'code.py' in result + assert 'code.js' not in result + + +async def test_local_grep_line_numbers(tmp_path: Path): + async with LocalEnvironment(tmp_path) as env: + await env.write_file('test.txt', 'alpha\nbeta\ngamma\nbeta\n') + + result = await env.grep('beta') + assert result == snapshot('test.txt:2:beta\ntest.txt:4:beta') + + +async def test_local_grep_no_matches(tmp_path: Path): + async with LocalEnvironment(tmp_path) as env: + await env.write_file('test.txt', 'nothing interesting') + result = await env.grep('nonexistent_pattern') + assert result == '' + + +async def test_local_grep_skips_hidden_files(tmp_path: Path): + async with LocalEnvironment(tmp_path) as env: + await env.write_file('visible.py', 'target_string\n') + (tmp_path / '.hidden').mkdir() + (tmp_path / '.hidden' / 'secret.py').write_text('target_string\n') + (tmp_path / '.dotfile').write_text('target_string\n') + + result = await env.grep('target_string') + assert 'visible.py' in result + assert '.hidden' not in result + assert '.dotfile' not in result + + +# --- LocalEnvironment: create_process --- + + +async def test_local_create_process(tmp_path: Path): + async with LocalEnvironment(tmp_path) as env: + proc = await env.create_process('echo interactive') + async with proc: + data = await proc.recv(timeout=5) + assert b'interactive' in data + + +async def test_local_create_process_env(tmp_path: Path): + async with LocalEnvironment(tmp_path) as env: + proc = await env.create_process('echo $PROC_VAR', env={'PROC_VAR': 'from_process'}) + async with proc: + data = await proc.recv(timeout=5) + assert b'from_process' in data + + +async def test_local_create_process_stdin(tmp_path: Path): + async with LocalEnvironment(tmp_path) as env: + # Use head -1 so the process exits after reading one line + proc = await env.create_process('head -1') + async with proc: + await proc.send(b'hello from stdin\n') + data = await proc.recv(timeout=5) + assert b'hello from stdin' in data + + +async def test_local_process_wait(tmp_path: Path): + async with LocalEnvironment(tmp_path) as env: + proc = await env.create_process('exit 7') + async with proc: + rc = await proc.wait(timeout=5) + assert rc == 7 + + +async def test_local_process_kill(tmp_path: Path): + async with LocalEnvironment(tmp_path) as env: + proc = await env.create_process('sleep 60') + # Don't use async with -- we want to test manual kill + await proc.kill() + assert proc.returncode is not None + + +# --- LocalEnvironment: path traversal --- + + +async def test_local_path_traversal_blocked(tmp_path: Path): + async with LocalEnvironment(tmp_path) as env: + with pytest.raises(PermissionError, match='outside the environment root'): + await env.read_file('../../../etc/passwd') + + +async def test_local_path_traversal_write_blocked(tmp_path: Path): + async with LocalEnvironment(tmp_path) as env: + with pytest.raises(PermissionError, match='outside the environment root'): + await env.write_file('../escape.txt', 'malicious') + + +# --- LocalEnvironment: creates root dir --- + + +async def test_local_creates_root_dir(tmp_path: Path): + root = tmp_path / 'new_root' + assert not root.exists() + async with LocalEnvironment(root) as env: + assert root.exists() + result = await env.shell('echo works') + assert 'works' in result.output + + +# --- ExecutionEnvironmentToolset --- + + +async def test_toolset_tool_names(): + toolset = ExecutionEnvironmentToolset(LocalEnvironment('.')) + tool_names = sorted(toolset.tools.keys()) + assert tool_names == snapshot(['glob', 'grep', 'ls', 'read_file', 'replace_str', 'shell', 'write_file']) + + +async def test_toolset_include_flags(): + toolset = ExecutionEnvironmentToolset( + LocalEnvironment('.'), + include=frozenset(), + ) + assert toolset.tools == {} + + +async def test_toolset_include_shell_only(): + toolset = ExecutionEnvironmentToolset( + LocalEnvironment('.'), + include=frozenset({'shell'}), + ) + assert sorted(toolset.tools.keys()) == ['shell'] + + +async def test_toolset_bash_tool(tmp_path: Path): + env = LocalEnvironment(tmp_path) + toolset = ExecutionEnvironmentToolset(env) + ctx = build_run_context(None) + manager = await ToolManager[None](toolset).for_run_step(ctx) + + async with env: + result = await manager.handle_call(ToolCallPart(tool_name='shell', args={'command': 'echo hello'})) + assert result == snapshot("""\ +hello + +Exit code: 0\ +""") + + +async def test_toolset_read_write_tools(tmp_path: Path): + env = LocalEnvironment(tmp_path) + toolset = ExecutionEnvironmentToolset(env) + ctx = build_run_context(None) + manager = await ToolManager[None](toolset).for_run_step(ctx) + + async with env: + # Write + write_result = await manager.handle_call( + ToolCallPart(tool_name='write_file', args={'path': 'test.txt', 'content': 'hello world'}) + ) + assert write_result == snapshot('File written: test.txt') + + # Read + read_result = await manager.handle_call(ToolCallPart(tool_name='read_file', args={'path': 'test.txt'})) + assert read_result == snapshot(' 1\thello world\n') + + +async def test_toolset_edit_retry_on_error(tmp_path: Path): + env = LocalEnvironment(tmp_path) + toolset = ExecutionEnvironmentToolset(env, max_retries=0) + ctx = build_run_context(None) + manager = await ToolManager[None](toolset).for_run_step(ctx) + + async with env: + await env.write_file('test.txt', 'content') + + # Edit with non-matching string: ModelRetry is raised by tool, but with max_retries=0 + # the ToolManager wraps it into UnexpectedModelBehavior + with pytest.raises(UnexpectedModelBehavior, match='exceeded max retries count of 0'): + await manager.handle_call( + ToolCallPart( + tool_name='replace_str', + args={'path': 'test.txt', 'old': 'nonexistent', 'new': 'replacement'}, + ) + ) + + +async def test_toolset_glob_tool(tmp_path: Path): + env = LocalEnvironment(tmp_path) + toolset = ExecutionEnvironmentToolset(env) + ctx = build_run_context(None) + manager = await ToolManager[None](toolset).for_run_step(ctx) + + async with env: + await env.write_file('a.py', '# a') + await env.write_file('b.py', '# b') + + result = await manager.handle_call(ToolCallPart(tool_name='glob', args={'pattern': '*.py'})) + assert result == snapshot("""\ +a.py +b.py\ +""") + + +async def test_toolset_grep_tool(tmp_path: Path): + env = LocalEnvironment(tmp_path) + toolset = ExecutionEnvironmentToolset(env) + ctx = build_run_context(None) + manager = await ToolManager[None](toolset).for_run_step(ctx) + + async with env: + await env.write_file('search.py', 'def find_me():\n pass\n') + + result = await manager.handle_call(ToolCallPart(tool_name='grep', args={'pattern': 'find_me'})) + assert result == snapshot('search.py:1:def find_me():') + + +# --- ExecutionEnvironmentToolset: error handling --- + + +async def test_toolset_read_nonexistent_returns_error(tmp_path: Path): + """read_file on a nonexistent file returns an error string instead of crashing.""" + env = LocalEnvironment(tmp_path) + toolset = ExecutionEnvironmentToolset(env) + ctx = build_run_context(None) + manager = await ToolManager[None](toolset).for_run_step(ctx) + + async with env: + result = await manager.handle_call(ToolCallPart(tool_name='read_file', args={'path': 'nope.txt'})) + assert 'Error:' in str(result) + + +async def test_toolset_read_path_traversal_returns_error(tmp_path: Path): + """read_file with path traversal returns an error string.""" + env = LocalEnvironment(tmp_path) + toolset = ExecutionEnvironmentToolset(env) + ctx = build_run_context(None) + manager = await ToolManager[None](toolset).for_run_step(ctx) + + async with env: + result = await manager.handle_call(ToolCallPart(tool_name='read_file', args={'path': '../../etc/passwd'})) + assert 'Error:' in str(result) + + +async def test_toolset_write_path_traversal_returns_error(tmp_path: Path): + """write_file with path traversal returns an error string.""" + env = LocalEnvironment(tmp_path) + toolset = ExecutionEnvironmentToolset(env) + ctx = build_run_context(None) + manager = await ToolManager[None](toolset).for_run_step(ctx) + + async with env: + result = await manager.handle_call( + ToolCallPart(tool_name='write_file', args={'path': '../../tmp/evil.txt', 'content': 'bad'}) + ) + assert 'Error:' in str(result) + + +async def test_toolset_glob_path_traversal_returns_error(tmp_path: Path): + """glob with path traversal returns an error string.""" + env = LocalEnvironment(tmp_path) + toolset = ExecutionEnvironmentToolset(env) + ctx = build_run_context(None) + manager = await ToolManager[None](toolset).for_run_step(ctx) + + async with env: + result = await manager.handle_call( + ToolCallPart(tool_name='glob', args={'pattern': '*.py', 'path': '../../etc'}) + ) + assert 'Error:' in str(result) + + +async def test_toolset_grep_invalid_regex_returns_error(tmp_path: Path): + """grep with invalid regex returns an error string.""" + env = LocalEnvironment(tmp_path) + toolset = ExecutionEnvironmentToolset(env) + ctx = build_run_context(None) + manager = await ToolManager[None](toolset).for_run_step(ctx) + + async with env: + await env.write_file('test.txt', 'content') + + result = await manager.handle_call(ToolCallPart(tool_name='grep', args={'pattern': '[invalid'})) + assert 'Error:' in str(result) + + +async def test_toolset_read_offset_out_of_bounds_returns_error(tmp_path: Path): + """read_file with offset past EOF returns an error string.""" + env = LocalEnvironment(tmp_path) + toolset = ExecutionEnvironmentToolset(env) + ctx = build_run_context(None) + manager = await ToolManager[None](toolset).for_run_step(ctx) + + async with env: + await env.write_file('short.txt', 'one\ntwo\n') + + result = await manager.handle_call( + ToolCallPart(tool_name='read_file', args={'path': 'short.txt', 'offset': 100}) + ) + assert 'Error:' in str(result) + assert 'Offset 100 exceeds' in str(result) + + +async def test_toolset_read_continuation_hint(tmp_path: Path): + """read_file includes continuation hint when there are more lines.""" + env = LocalEnvironment(tmp_path) + toolset = ExecutionEnvironmentToolset(env) + ctx = build_run_context(None) + manager = await ToolManager[None](toolset).for_run_step(ctx) + + async with env: + lines = '\n'.join(f'line {i}' for i in range(20)) + await env.write_file('long.txt', lines) + + result = await manager.handle_call( + ToolCallPart(tool_name='read_file', args={'path': 'long.txt', 'offset': 0, 'limit': 5}) + ) + assert result == snapshot("""\ + 1 line 0 + 2 line 1 + 3 line 2 + 4 line 3 + 5 line 4 +... (15 more lines. Use offset=5 to continue reading.) +""") + + +# --- ExecutionEnvironmentToolset: approval flags --- + + +async def test_toolset_require_shell_approval(): + """require_shell_approval sets requires_approval on the shell tool.""" + toolset = ExecutionEnvironmentToolset(require_shell_approval=True) + ctx = build_run_context(None) + tools = await toolset.get_tools(ctx) + assert tools['shell'].tool_def.kind == 'unapproved' + # Other tools should be normal + assert tools['read_file'].tool_def.kind == 'function' + + +async def test_toolset_require_write_approval(): + """require_write_approval sets requires_approval on write_file and replace_str.""" + toolset = ExecutionEnvironmentToolset(require_write_approval=True) + ctx = build_run_context(None) + tools = await toolset.get_tools(ctx) + assert tools['write_file'].tool_def.kind == 'unapproved' + assert tools['replace_str'].tool_def.kind == 'unapproved' + # read_file and search tools should NOT require approval + assert tools['read_file'].tool_def.kind == 'function' + assert tools['glob'].tool_def.kind == 'function' + assert tools['grep'].tool_def.kind == 'function' + + +async def test_toolset_default_no_approval(): + """By default, no tools require approval.""" + toolset = ExecutionEnvironmentToolset() + ctx = build_run_context(None) + tools = await toolset.get_tools(ctx) + for tool in tools.values(): + assert tool.tool_def.kind == 'function' + + +# --- ExecutionEnvironmentToolset: environment management --- + + +async def test_toolset_environment_property(): + env = LocalEnvironment('.') + toolset = ExecutionEnvironmentToolset(env) + assert toolset.environment is env + assert toolset.required_environment is env + + +async def test_toolset_no_environment_returns_none(): + toolset = ExecutionEnvironmentToolset() + assert toolset.environment is None + + +async def test_toolset_no_environment_required_raises(): + toolset = ExecutionEnvironmentToolset() + with pytest.raises(RuntimeError, match='No execution environment configured'): + _ = toolset.required_environment + + +async def test_toolset_use_environment(): + env1 = LocalEnvironment('/tmp/env1') + env2 = LocalEnvironment('/tmp/env2') + toolset = ExecutionEnvironmentToolset(env1) + + assert toolset.environment is env1 + with toolset.use_environment(env2): + assert toolset.environment is env2 + assert toolset.environment is env1 + + +async def test_toolset_use_environment_no_default(): + env = LocalEnvironment('.') + toolset = ExecutionEnvironmentToolset() + + assert toolset.environment is None + + with toolset.use_environment(env): + assert toolset.environment is env + + assert toolset.environment is None + + +async def test_toolset_instructions(): + """Environment instructions is accessible for each tool.""" + env = LocalEnvironment('.') + # LocalEnvironment returns None for all tool descriptions by default + assert env.instructions('shell') is None + assert env.instructions('read_file') is None + + +async def test_toolset_tool_name_conflict_hint(): + toolset = ExecutionEnvironmentToolset(LocalEnvironment('.')) + assert 'PrefixedToolset' in toolset.tool_name_conflict_hint + + +# --- ExecutionEnvironmentToolset: lifecycle --- + + +async def test_toolset_lifecycle(tmp_path: Path): + env = LocalEnvironment(tmp_path) + toolset = ExecutionEnvironmentToolset(env) + + async with toolset: + result = await env.shell('echo lifecycle') + assert 'lifecycle' in result.output + + +# --- ExecutionEnvironmentToolset: image support --- + + +async def test_toolset_image_support_disabled(tmp_path: Path): + env = LocalEnvironment(tmp_path) + toolset = ExecutionEnvironmentToolset(env, image_support=False) + ctx = build_run_context(None) + manager = await ToolManager[None](toolset).for_run_step(ctx) + + async with env: + await env.write_file('photo.png', b'\x89PNG\r\n\x1a\n') + result = await manager.handle_call(ToolCallPart(tool_name='read_file', args={'path': 'photo.png'})) + assert result == snapshot('[Image file: photo.png — image_support is disabled on this toolset]') + + +# --- LocalEnvironment: grep output modes --- + + +async def test_local_grep_files_with_matches(tmp_path: Path): + async with LocalEnvironment(tmp_path) as env: + await env.write_file('a.py', 'target = 1\nother = 2\n') + await env.write_file('b.py', 'target = 3\ntarget = 4\n') + await env.write_file('c.py', 'nothing here\n') + + result = await env.grep('target', output_mode='files_with_matches') + lines = result.strip().splitlines() + assert sorted(lines) == ['a.py', 'b.py'] + + +async def test_local_grep_count(tmp_path: Path): + async with LocalEnvironment(tmp_path) as env: + await env.write_file('a.py', 'target = 1\nother = 2\n') + await env.write_file('b.py', 'target = 3\ntarget = 4\n') + await env.write_file('c.py', 'nothing here\n') + + result = await env.grep('target', output_mode='count') + lines = sorted(result.strip().splitlines()) + assert lines == ['a.py:1', 'b.py:2'] + + +async def test_local_grep_content_default(tmp_path: Path): + """Default output_mode is 'content' with file:line:text format.""" + async with LocalEnvironment(tmp_path) as env: + await env.write_file('test.py', 'hello\nworld\n') + + result = await env.grep('hello') + assert result == snapshot('test.py:1:hello') + + +# --- LocalEnvironment: binary file detection --- + + +async def test_local_grep_skips_binary_files(tmp_path: Path): + async with LocalEnvironment(tmp_path) as env: + await env.write_file('text.py', 'findme = True\n') + await env.write_file('binary.pyc', b'\x00\x01\x02findme\x03\x04') + + result = await env.grep('findme') + assert 'text.py' in result + assert 'binary.pyc' not in result + + +async def test_local_grep_binary_detection_first_8kb(tmp_path: Path): + """Binary detection checks only the first 8KB.""" + async with LocalEnvironment(tmp_path) as env: + # File with null byte after 8KB -- should be treated as text + content = 'findme\n' + ('x' * 8200) + '\x00' + await env.write_file('mostly_text.txt', content) + + result = await env.grep('findme') + assert 'mostly_text.txt' in result + + +# --- Toolset: grep output_mode --- + + +async def test_toolset_grep_files_with_matches(tmp_path: Path): + env = LocalEnvironment(tmp_path) + toolset = ExecutionEnvironmentToolset(env) + ctx = build_run_context(None) + manager = await ToolManager[None](toolset).for_run_step(ctx) + + async with env: + await env.write_file('a.py', 'target = 1\n') + await env.write_file('b.py', 'other = 2\n') + + result = await manager.handle_call( + ToolCallPart(tool_name='grep', args={'pattern': 'target', 'output_mode': 'files_with_matches'}) + ) + assert result == snapshot('a.py') + + +async def test_toolset_grep_count(tmp_path: Path): + env = LocalEnvironment(tmp_path) + toolset = ExecutionEnvironmentToolset(env) + ctx = build_run_context(None) + manager = await ToolManager[None](toolset).for_run_step(ctx) + + async with env: + await env.write_file('a.py', 'x = 1\nx = 2\nx = 3\n') + + result = await manager.handle_call( + ToolCallPart(tool_name='grep', args={'pattern': 'x', 'output_mode': 'count'}) + ) + assert result == snapshot('a.py:3') + + +# --- MemoryEnvironment --- + + +async def test_memory_read_write(): + async with MemoryEnvironment() as env: + await env.write_file('test.txt', 'hello world\n') + content = await env.read_file('test.txt') + assert content == snapshot("""\ + 1\thello world +""") + + +async def test_memory_initial_files(): + env = MemoryEnvironment(files={'a.txt': 'alpha', 'b.txt': 'beta'}) + async with env: + a = await env.read_file('a.txt') + assert isinstance(a, str) + assert 'alpha' in a + b = await env.read_file('b.txt') + assert isinstance(b, str) + assert 'beta' in b + + +async def test_memory_read_nonexistent(): + async with MemoryEnvironment() as env: + with pytest.raises(FileNotFoundError): + await env.read_file('nope.txt') + + +async def test_memory_read_directory_error(): + env = MemoryEnvironment(files={'dir/file.txt': 'content'}) + async with env: + with pytest.raises(FileNotFoundError, match='is a directory'): + await env.read_file('dir') + + +async def test_memory_read_offset_limit(): + lines = '\n'.join(f'line {i}' for i in range(20)) + env = MemoryEnvironment(files={'long.txt': lines}) + async with env: + content = await env.read_file('long.txt', offset=5, limit=3) + assert isinstance(content, str) + assert 'line 5' in content + assert 'line 7' in content + assert 'line 4' not in content + assert 'line 8' not in content + + +async def test_memory_read_continuation_hint(): + lines = '\n'.join(f'line {i}' for i in range(20)) + env = MemoryEnvironment(files={'long.txt': lines}) + async with env: + content = await env.read_file('long.txt', offset=0, limit=5) + assert isinstance(content, str) + assert '15 more lines' in content + assert 'offset=5' in content + + +async def test_memory_read_offset_out_of_bounds(): + env = MemoryEnvironment(files={'short.txt': 'one\ntwo\n'}) + async with env: + with pytest.raises(ValueError, match='Offset 100 exceeds'): + await env.read_file('short.txt', offset=100) + + +async def test_memory_edit_file(): + env = MemoryEnvironment(files={'code.py': 'old_value = 1'}) + async with env: + count = await env.replace_str('code.py', 'old_value', 'new_value') + assert count == 1 + content = await env.read_file('code.py') + assert isinstance(content, str) + assert 'new_value' in content + assert 'old_value' not in content + + +async def test_memory_edit_file_not_found(): + async with MemoryEnvironment() as env: + with pytest.raises(FileNotFoundError): + await env.replace_str('nope.txt', 'a', 'b') + + +async def test_memory_edit_string_not_found(): + env = MemoryEnvironment(files={'f.txt': 'hello'}) + async with env: + with pytest.raises(ValueError, match='not found'): + await env.replace_str('f.txt', 'missing', 'replacement') + + +async def test_memory_edit_ambiguous(): + env = MemoryEnvironment(files={'f.txt': 'dup dup dup'}) + async with env: + with pytest.raises(ValueError, match='3 times'): + await env.replace_str('f.txt', 'dup', 'x') + + +async def test_memory_edit_replace_all(): + env = MemoryEnvironment(files={'f.txt': 'aaa bbb aaa'}) + async with env: + count = await env.replace_str('f.txt', 'aaa', 'xxx', replace_all=True) + assert count == 2 + content = await env.read_file('f.txt') + assert isinstance(content, str) + assert 'xxx bbb xxx' in content + + +async def test_memory_ls(): + env = MemoryEnvironment( + files={ + 'a.txt': 'a', + 'b.txt': 'bb', + 'sub/c.txt': 'ccc', + } + ) + async with env: + entries = await env.ls('.') + names = {e.name for e in entries} + assert names == {'a.txt', 'b.txt', 'sub'} + + dirs = [e for e in entries if e.is_dir] + files = [e for e in entries if not e.is_dir] + assert len(dirs) == 1 + assert dirs[0].name == 'sub' + assert all(f.size is not None for f in files) + + +async def test_memory_ls_subdirectory(): + env = MemoryEnvironment(files={'sub/a.txt': 'a', 'sub/b.txt': 'b'}) + async with env: + entries = await env.ls('sub') + names = {e.name for e in entries} + assert names == {'a.txt', 'b.txt'} + + +async def test_memory_ls_not_a_directory(): + async with MemoryEnvironment() as env: + with pytest.raises(NotADirectoryError): + await env.ls('nonexistent') + + +async def test_memory_glob(): + env = MemoryEnvironment( + files={ + 'src/main.py': '# main', + 'src/utils.py': '# utils', + 'src/data.json': '{}', + } + ) + async with env: + matches = await env.glob('*.py', path='src') + assert sorted(matches) == ['src/main.py', 'src/utils.py'] + + +async def test_memory_glob_no_matches(): + env = MemoryEnvironment(files={'a.py': ''}) + async with env: + matches = await env.glob('*.rs') + assert matches == [] + + +async def test_memory_grep_content(): + env = MemoryEnvironment( + files={ + 'a.py': 'def hello():\n pass\n', + 'b.py': 'x = 1\n', + } + ) + async with env: + result = await env.grep('hello') + assert result == snapshot('a.py:1:def hello():') + + +async def test_memory_grep_files_with_matches(): + env = MemoryEnvironment( + files={ + 'a.py': 'target = 1\n', + 'b.py': 'target = 2\ntarget = 3\n', + 'c.py': 'nothing\n', + } + ) + async with env: + result = await env.grep('target', output_mode='files_with_matches') + lines = sorted(result.strip().splitlines()) + assert lines == ['a.py', 'b.py'] + + +async def test_memory_grep_count(): + env = MemoryEnvironment( + files={ + 'a.py': 'x = 1\n', + 'b.py': 'x = 2\nx = 3\n', + } + ) + async with env: + result = await env.grep('x', output_mode='count') + lines = sorted(result.strip().splitlines()) + assert lines == ['a.py:1', 'b.py:2'] + + +async def test_memory_grep_skips_binary(): + env = MemoryEnvironment( + files={ + 'text.py': 'findme = True\n', + 'binary.dat': b'\x00\x01findme\x02', + } + ) + async with env: + result = await env.grep('findme') + assert 'text.py' in result + assert 'binary.dat' not in result + + +async def test_memory_grep_skips_hidden(): + env = MemoryEnvironment( + files={ + 'visible.py': 'target\n', + '.hidden/secret.py': 'target\n', + } + ) + async with env: + result = await env.grep('target') + assert 'visible.py' in result + assert '.hidden' not in result + + +async def test_memory_grep_with_glob_pattern(): + env = MemoryEnvironment( + files={ + 'code.py': 'target\n', + 'code.js': 'target\n', + } + ) + async with env: + result = await env.grep('target', glob_pattern='*.py') + assert 'code.py' in result + assert 'code.js' not in result + + +async def test_memory_execute_with_handler(): + def handler(cmd: str) -> ExecutionResult: + return ExecutionResult(output=f'ran: {cmd}\n', exit_code=0) + + async with MemoryEnvironment(command_handler=handler) as env: + result = await env.shell('echo hello') + assert result.output == 'ran: echo hello\n' + assert result.exit_code == 0 + + +async def test_memory_execute_no_handler(): + async with MemoryEnvironment() as env: + with pytest.raises(RuntimeError, match='no command_handler'): + await env.shell('echo hello') + + +async def test_memory_create_process_not_supported(): + async with MemoryEnvironment() as env: + with pytest.raises(NotImplementedError): + await env.create_process('echo hello') + + +async def test_memory_write_binary(): + async with MemoryEnvironment() as env: + await env.write_file('data.bin', b'\x00\x01\x02') + # Non-image binary files are returned as text (decoded) + content = await env.read_file('data.bin') + assert isinstance(content, str) + + +async def test_memory_read_file_bytes(): + png_data = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR' + env = MemoryEnvironment(files={'img.png': png_data}) + async with env: + result = await env.read_file('img.png') + assert isinstance(result, bytes) + assert result == png_data + + +# --- MemoryEnvironment with ExecutionEnvironmentToolset --- + + +async def test_memory_toolset_integration(): + """MemoryEnvironment works with ExecutionEnvironmentToolset for full agent testing.""" + env = MemoryEnvironment(files={'main.py': 'print("hello")\n'}) + toolset = ExecutionEnvironmentToolset(env, exclude=frozenset({'shell', 'run_code'})) + ctx = build_run_context(None) + manager = await ToolManager[None](toolset).for_run_step(ctx) + + async with env: + # read_file + result = await manager.handle_call(ToolCallPart(tool_name='read_file', args={'path': 'main.py'})) + assert result == snapshot(' 1\tprint("hello")\n') + + # write_file + result = await manager.handle_call( + ToolCallPart(tool_name='write_file', args={'path': 'new.py', 'content': 'x = 1'}) + ) + assert result == snapshot('File written: new.py') + + # glob + result = await manager.handle_call(ToolCallPart(tool_name='glob', args={'pattern': '*.py'})) + assert result == snapshot("""\ +main.py +new.py\ +""") + + # grep + result = await manager.handle_call(ToolCallPart(tool_name='grep', args={'pattern': 'hello'})) + assert result == snapshot('main.py:1:print("hello")') + + +# --- Agent-level integration test --- + + +async def test_agent_with_execution_toolset(): + """Agent with ExecutionEnvironmentToolset runs end-to-end using TestModel and MemoryEnvironment.""" + from pydantic_ai import Agent + + env = MemoryEnvironment( + files={'data.txt': 'hello world\n'}, + command_handler=lambda cmd: ExecutionResult(output=f'executed: {cmd}\n', exit_code=0), + ) + toolset = ExecutionEnvironmentToolset(env) + + agent = Agent('test', toolsets=[toolset]) + + async with env: + result = await agent.run('Read the file data.txt') + # The TestModel will call tools and we verify it completes without error + assert result.output is not None + + +# pyright: reportPrivateUsage=false, reportUnknownMemberType=false, reportUnknownArgumentType=false, reportUnknownVariableType=false, reportPossiblyUnboundVariable=false + + +# --- _base.py helper functions --- + + +def test_shell_escape(): + assert shell_escape('hello') == "'hello'" + assert shell_escape("it's") == "'it'\\''s'" + assert shell_escape('') == "''" + assert shell_escape('a b c') == "'a b c'" + + +def test_format_lines_empty_file(): + """format_lines on empty string returns just a newline.""" + result = format_lines('', 0, 2000) + assert result == '\n' + + +def test_format_lines_trailing_newline(): + """format_lines adds trailing newline when text doesn't end with one.""" + result = format_lines('no trailing newline', 0, 2000) + assert result.endswith('\n') + assert '1\tno trailing newline' in result + + +def test_glob_match_simple(): + assert glob_match('foo.py', '*.py') is True + assert glob_match('foo.txt', '*.py') is False + + +def test_glob_match_double_star(): + """glob_match with ** patterns for recursive matching.""" + assert glob_match('src/main.py', '**/*.py') is True + assert glob_match('deep/nested/dir/file.py', '**/*.py') is True + assert glob_match('file.py', '**/*.py') is True + assert glob_match('src/main.txt', '**/*.py') is False + + +def test_glob_match_double_star_prefix(): + """glob_match with **/ prefix.""" + assert glob_match('a/b/c.txt', '**/c.txt') is True + assert glob_match('c.txt', '**/c.txt') is True + + +def test_glob_match_double_star_suffix(): + """glob_match with ** at end.""" + assert glob_match('src/foo/bar', 'src/**') is True + + +def test_glob_match_question_mark(): + """glob_match with ? wildcard.""" + assert glob_match('test.py', 'tes?.py') is True + assert glob_match('test.py', 'te??.py') is True + assert glob_match('test.py', 't???.py') is True # t + 3 chars (est) + .py + assert glob_match('test.py', 't????.py') is False # needs 4 chars between t and .py + + +def test_build_read_file_cmd_default(): + cmd = build_read_file_cmd('test.txt') + assert 'awk' in cmd + assert "'test.txt'" in cmd + assert 'NR>=1' in cmd + assert 'NR<=2000' in cmd + + +def test_build_read_file_cmd_with_offset(): + cmd = build_read_file_cmd('file.py', offset=10, limit=50) + assert 'NR>=11' in cmd + assert 'NR<=60' in cmd + assert "'file.py'" in cmd + + +def test_build_read_file_cmd_continuation_hint(): + """build_read_file_cmd includes a continuation hint in the awk END block.""" + cmd = build_read_file_cmd('file.py', offset=0, limit=10) + assert 'more lines' in cmd + assert 'offset=10' in cmd + + +def test_build_grep_cmd_content(): + cmd = build_grep_cmd('pattern') + assert 'grep -rI' in cmd + assert '-n' in cmd + assert "'pattern'" in cmd + assert "'.'" in cmd + + +def test_build_grep_cmd_files_with_matches(): + cmd = build_grep_cmd('pat', output_mode='files_with_matches') + assert '-l' in cmd + assert '-n' not in cmd + + +def test_build_grep_cmd_count(): + cmd = build_grep_cmd('pat', output_mode='count') + assert '-c' in cmd + + +def test_build_grep_cmd_with_path(): + cmd = build_grep_cmd('pat', path='src') + assert "'src'" in cmd + + +def test_build_grep_cmd_with_glob_pattern(): + """glob_pattern is shell-escaped to prevent injection.""" + cmd = build_grep_cmd('pat', glob_pattern='*.py') + assert '--include' in cmd + assert "'*.py'" in cmd + + +def test_build_grep_cmd_glob_pattern_escaping(): + """Verify glob_pattern with special chars is properly shell-escaped.""" + cmd = build_grep_cmd('pat', glob_pattern='*.py') + # The glob pattern should be shell-escaped (wrapped in single quotes) + assert "--include '*.py'" in cmd + + # Even a malicious glob_pattern gets safely escaped + cmd2 = build_grep_cmd('pat', glob_pattern='$(evil)') + assert '$(evil)' not in cmd2.replace("'$(evil)'", '') # Only appears inside quotes + + +def test_build_glob_cmd(): + cmd = build_glob_cmd('*.py') + assert 'find' in cmd + assert "'*.py'" in cmd + assert "'.'" in cmd + + +def test_build_glob_cmd_with_path(): + cmd = build_glob_cmd('*.py', path='src') + assert "'src'" in cmd + + +def test_parse_glob_output_empty(): + assert parse_glob_output('') == [] + assert parse_glob_output(' ') == [] + assert parse_glob_output('\n') == [] + + +def test_parse_glob_output_multiline(): + assert parse_glob_output('a.py\nb.py\nc.py\n') == ['a.py', 'b.py', 'c.py'] + + +def test_filter_grep_count_output(): + text = 'a.py:3\nb.py:0\nc.py:1' + result = filter_grep_count_output(text) + assert result == 'a.py:3\nc.py:1' + + +def test_filter_grep_count_output_all_zero(): + text = 'a.py:0\nb.py:0' + result = filter_grep_count_output(text) + assert result == '' + + +def test_apply_edit_basic(): + new_text, count = apply_edit('hello world', 'world', 'earth', 'test.txt', replace_all=False) + assert new_text == 'hello earth' + assert count == 1 + + +def test_apply_edit_replace_all(): + new_text, count = apply_edit('aaa bbb aaa', 'aaa', 'xxx', 'test.txt', replace_all=True) + assert new_text == 'xxx bbb xxx' + assert count == 2 + + +def test_apply_edit_not_found(): + with pytest.raises(ValueError, match='not found'): + apply_edit('hello', 'missing', 'x', 'test.txt', replace_all=False) + + +def test_apply_edit_ambiguous(): + with pytest.raises(ValueError, match='2 times'): + apply_edit('aa bb aa', 'aa', 'x', 'test.txt', replace_all=False) + + +# --- LocalEnvironment: additional edge cases --- + + +async def test_local_execute_no_timeout(tmp_path: Path): + """execute() with timeout=None completes without timeout.""" + async with LocalEnvironment(tmp_path) as env: + result = await env.shell('echo no_timeout', timeout=None) + assert result.exit_code == 0 + assert 'no_timeout' in result.output + + +async def test_local_read_file_bytes_directory(tmp_path: Path): + """read_file_bytes on a directory raises FileNotFoundError.""" + async with LocalEnvironment(tmp_path) as env: + (tmp_path / 'adir').mkdir() + with pytest.raises(FileNotFoundError, match='is a directory'): + await env.read_file('adir') + + +async def test_local_read_file_bytes_nonexistent(tmp_path: Path): + """read_file_bytes on a nonexistent file raises FileNotFoundError.""" + async with LocalEnvironment(tmp_path) as env: + with pytest.raises(FileNotFoundError): + await env.read_file('nope.bin') + + +async def test_local_grep_specific_file(tmp_path: Path): + """grep targeting a specific file works.""" + async with LocalEnvironment(tmp_path) as env: + await env.write_file('target.py', 'findme = True\n') + await env.write_file('other.py', 'findme = False\n') + + result = await env.grep('findme', path='target.py') + assert 'target.py' in result + assert 'other.py' not in result + + +# --- MemoryEnvironment: additional edge cases --- + + +async def test_memory_normalize_paths(): + """MemoryEnvironment normalizes paths correctly.""" + async with MemoryEnvironment() as env: + await env.write_file('./test.txt', 'content') + content = await env.read_file('test.txt') + assert isinstance(content, str) + assert 'content' in content + + +async def test_memory_normalize_leading_slash(): + """MemoryEnvironment strips leading slashes.""" + async with MemoryEnvironment() as env: + await env.write_file('/test.txt', 'content') + content = await env.read_file('test.txt') + assert isinstance(content, str) + assert 'content' in content + + +async def test_memory_read_file_text(): + """read_file on text file returns formatted string.""" + env = MemoryEnvironment(files={'text.txt': 'hello'}) + async with env: + result = await env.read_file('text.txt') + assert isinstance(result, str) + assert 'hello' in result + + +async def test_memory_read_file_not_found(): + """read_file on missing file raises FileNotFoundError.""" + async with MemoryEnvironment() as env: + with pytest.raises(FileNotFoundError): + await env.read_file('missing.txt') + + +async def test_memory_edit_binary(): + """edit_file works on binary content.""" + env = MemoryEnvironment(files={'data.txt': b'hello world'}) + async with env: + count = await env.replace_str('data.txt', 'world', 'earth') + assert count == 1 + + +async def test_memory_grep_exact_path(): + """grep with path= targeting an exact file.""" + env = MemoryEnvironment( + files={ + 'src/a.py': 'target\n', + 'src/b.py': 'target\n', + } + ) + async with env: + result = await env.grep('target', path='src/a.py') + assert 'src/a.py' in result + assert 'src/b.py' not in result + + +async def test_memory_grep_no_text_content(): + """grep with text bytes (non-binary) works.""" + env = MemoryEnvironment(files={'data.txt': b'findme in bytes'}) + async with env: + result = await env.grep('findme') + assert 'data.txt' in result + + +async def test_memory_glob_recursive(): + """glob with ** pattern.""" + env = MemoryEnvironment( + files={ + 'src/a.py': '', + 'src/sub/b.py': '', + 'other.txt': '', + } + ) + async with env: + matches = await env.glob('**/*.py') + assert 'src/a.py' in matches + assert 'src/sub/b.py' in matches + assert 'other.txt' not in matches + + +async def test_memory_glob_in_subdirectory(): + """glob with path= restricts to subdirectory.""" + env = MemoryEnvironment( + files={ + 'src/a.py': '', + 'lib/b.py': '', + } + ) + async with env: + matches = await env.glob('*.py', path='src') + assert 'src/a.py' in matches + assert 'lib/b.py' not in matches + + +async def test_memory_ls_with_bytes(): + """ls reports size correctly for bytes content.""" + env = MemoryEnvironment(files={'data.bin': b'\x00\x01\x02'}) + async with env: + entries = await env.ls('.') + assert len(entries) == 1 + assert entries[0].size == 3 + assert entries[0].is_dir is False + + +# --- ExecutionEnvironmentToolset: additional coverage --- + + +async def test_toolset_bash_truncated(tmp_path: Path): + """bash tool truncation message when output exceeds limit.""" + env = LocalEnvironment(tmp_path) + toolset = ExecutionEnvironmentToolset(env) + ctx = build_run_context(None) + manager = await ToolManager[None](toolset).for_run_step(ctx) + + async with env: + # Generate output longer than MAX_OUTPUT_CHARS (100_000) + result = await manager.handle_call( + ToolCallPart(tool_name='shell', args={'command': 'python3 -c "print(\'x\' * 200000)"'}) + ) + assert '[output truncated]' in str(result) + assert 'Exit code: 0' in str(result) + + +async def test_toolset_image_too_large(tmp_path: Path): + """read_file on an image that's too large returns error string.""" + env = LocalEnvironment(tmp_path) + toolset = ExecutionEnvironmentToolset(env, max_image_bytes=10) # Very small limit + ctx = build_run_context(None) + manager = await ToolManager[None](toolset).for_run_step(ctx) + + async with env: + # Write a PNG file that exceeds the limit + await env.write_file('big.png', b'\x89PNG\r\n\x1a\n' + b'\x00' * 100) + result = await manager.handle_call(ToolCallPart(tool_name='read_file', args={'path': 'big.png'})) + assert 'Image too large' in str(result) + + +async def test_toolset_image_read(tmp_path: Path): + """read_file on an image returns BinaryContent.""" + from pydantic_ai.messages import BinaryContent + + env = LocalEnvironment(tmp_path) + toolset = ExecutionEnvironmentToolset(env) + ctx = build_run_context(None) + manager = await ToolManager[None](toolset).for_run_step(ctx) + + async with env: + png_data = ( + b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01' + b'\x00\x00\x00\x01\x08\x06\x00\x00\x00\x1f\x15\xc4\x89' + b'\x00\x00\x00\nIDATx\x9cc\x00\x01\x00\x00\x05\x00\x01' + b'\r\n\xb4\x00\x00\x00\x00IEND\xaeB`\x82' + ) + await env.write_file('img.png', png_data) + result = await manager.handle_call(ToolCallPart(tool_name='read_file', args={'path': 'img.png'})) + assert isinstance(result, BinaryContent) + assert result.media_type == 'image/png' + + +async def test_toolset_grep_no_matches(tmp_path: Path): + """grep with no matches returns 'No matches found.'.""" + env = LocalEnvironment(tmp_path) + toolset = ExecutionEnvironmentToolset(env) + ctx = build_run_context(None) + manager = await ToolManager[None](toolset).for_run_step(ctx) + + async with env: + await env.write_file('test.txt', 'nothing relevant\n') + result = await manager.handle_call(ToolCallPart(tool_name='grep', args={'pattern': 'nonexistent_xyz'})) + assert result == snapshot('No matches found.') + + +async def test_toolset_glob_no_matches(tmp_path: Path): + """glob with no matches returns 'No files found.'.""" + env = LocalEnvironment(tmp_path) + toolset = ExecutionEnvironmentToolset(env) + ctx = build_run_context(None) + manager = await ToolManager[None](toolset).for_run_step(ctx) + + async with env: + result = await manager.handle_call(ToolCallPart(tool_name='glob', args={'pattern': '*.nonexistent'})) + assert result == snapshot('No files found.') + + +async def test_toolset_edit_success(tmp_path: Path): + """edit_file tool returns success message.""" + env = LocalEnvironment(tmp_path) + toolset = ExecutionEnvironmentToolset(env) + ctx = build_run_context(None) + manager = await ToolManager[None](toolset).for_run_step(ctx) + + async with env: + await env.write_file('code.py', 'old_value = 1\n') + result = await manager.handle_call( + ToolCallPart( + tool_name='replace_str', + args={'path': 'code.py', 'old': 'old_value', 'new': 'new_value'}, + ) + ) + assert result == snapshot('Replaced 1 occurrence in code.py.') + + +async def test_toolset_with_custom_env_instructions(): + """Environment instructions is used per-tool.""" + + class CustomEnv(MemoryEnvironment): + def instructions(self, capability: str) -> str | None: + if capability == 'grep': + return 'Custom grep description.' + return None + + env = CustomEnv() + assert env.instructions('grep') == 'Custom grep description.' + assert env.instructions('read_file') is None + + +async def test_toolset_lifecycle_ref_counting(tmp_path: Path): + """Multiple context manager entries share the environment.""" + env = LocalEnvironment(tmp_path) + toolset = ExecutionEnvironmentToolset(env) + + async with toolset: + async with toolset: + # Both entries active + result = await env.shell('echo shared') + assert 'shared' in result.output + # Still alive after one exit + result = await env.shell('echo still_alive') + assert 'still_alive' in result.output + + +# --- Additional coverage: _base.py --- + + +async def test_glob_match_question_mark_in_doublestar_pattern(): + """glob_match with ? inside a ** pattern.""" + assert glob_match('a/b/test.py', '**/?est.py') is True + assert glob_match('test.py', '?est.py') is True + + +async def test_execution_environment_aenter_aexit(): + """ExecutionEnvironment base __aenter__/__aexit__ are exercised by subclasses.""" + # MemoryEnvironment exercises the base class path + env = MemoryEnvironment() + async with env: + pass + + +# --- Additional coverage: _toolset.py --- + + +async def test_toolset_bash_empty_output(tmp_path: Path): + """ExecutionEnvironmentToolset bash returns just exit code when no output.""" + env = LocalEnvironment(tmp_path) + toolset = ExecutionEnvironmentToolset(env) + ctx = build_run_context() + manager = await ToolManager[None](toolset).for_run_step(ctx) + + async with env: + result = await manager.handle_call(ToolCallPart(tool_name='shell', args={'command': 'true'})) + assert 'Exit code: 0' in str(result) + + +async def test_toolset_glob_truncation(tmp_path: Path): + """ExecutionEnvironmentToolset glob truncates after 100 matches.""" + env = LocalEnvironment(tmp_path) + # Create 110 files + for i in range(110): + (tmp_path / f'file_{i:03d}.txt').write_text(f'content {i}') + + toolset = ExecutionEnvironmentToolset(env) + ctx = build_run_context() + manager = await ToolManager[None](toolset).for_run_step(ctx) + + async with env: + result = await manager.handle_call(ToolCallPart(tool_name='glob', args={'pattern': '*.txt'})) + assert 'truncated' in str(result) + + +async def test_toolset_grep_no_matches_returns_message(tmp_path: Path): + """ExecutionEnvironmentToolset grep returns message when no matches.""" + (tmp_path / 'test.txt').write_text('hello world') + env = LocalEnvironment(tmp_path) + toolset = ExecutionEnvironmentToolset(env) + ctx = build_run_context() + manager = await ToolManager[None](toolset).for_run_step(ctx) + + async with env: + result = await manager.handle_call(ToolCallPart(tool_name='grep', args={'pattern': 'zzz_nonexistent'})) + assert 'No matches' in str(result) + + +async def test_toolset_lifecycle_error(tmp_path: Path): + """ExecutionEnvironmentToolset handles environment startup failures.""" + + class FailingEnv(LocalEnvironment): + async def __aenter__(self): + raise RuntimeError('Setup failed') + + env = FailingEnv(tmp_path) + toolset = ExecutionEnvironmentToolset(env) + with pytest.raises(RuntimeError, match='Setup failed'): + async with toolset: + pass + + +# --- Additional coverage: local.py --- + + +async def test_local_process_stdin_not_available(): + """LocalEnvironmentProcess.send raises when stdin is None.""" + from pydantic_harness.environments.local import LocalEnvironmentProcess + + mock_proc = MagicMock() + mock_proc.stdin = None + proc = LocalEnvironmentProcess(mock_proc) + with pytest.raises(RuntimeError, match='stdin'): + await proc.send(b'data') + + +async def test_local_process_stdout_not_available(): + """LocalEnvironmentProcess.recv raises when stdout is None.""" + from pydantic_harness.environments.local import LocalEnvironmentProcess + + mock_proc = MagicMock() + mock_proc.stdout = None + proc = LocalEnvironmentProcess(mock_proc) + with pytest.raises(RuntimeError, match='stdout'): + await proc.recv() + + +async def test_local_process_stderr_not_available(): + """LocalEnvironmentProcess.recv_stderr raises when stderr is None.""" + from pydantic_harness.environments.local import LocalEnvironmentProcess + + mock_proc = MagicMock() + mock_proc.stderr = None + proc = LocalEnvironmentProcess(mock_proc) + with pytest.raises(RuntimeError, match='stderr'): + await proc.recv_stderr() + + +async def test_local_process_recv_stderr_timeout(tmp_path: Path): + """LocalEnvironmentProcess.recv_stderr with timeout.""" + env = LocalEnvironment(tmp_path) + proc = await env.create_process('python -c "import sys; sys.stderr.write(\'err\\n\')"') + async with proc: + data = await proc.recv_stderr(timeout=5.0) + assert b'err' in data + + +async def test_local_process_recv_stderr_eof(tmp_path: Path): + """LocalEnvironmentProcess.recv_stderr returns empty on EOF.""" + env = LocalEnvironment(tmp_path) + proc = await env.create_process('echo done') + async with proc: + await proc.wait(timeout=5.0) + # After process exits, stderr should return empty + data = await proc.recv_stderr() + assert data == b'' + + +async def test_local_process_kill_terminates_sleep(tmp_path: Path): + """LocalEnvironmentProcess.kill terminates process.""" + env = LocalEnvironment(tmp_path) + proc = await env.create_process('sleep 60') + async with proc: + await proc.kill() + # After kill, returncode should be set + + +async def test_local_read_file_bytes_directory_raises_error(tmp_path: Path): + """LocalEnvironment.read_file_bytes raises on directory.""" + (tmp_path / 'subdir').mkdir() + env = LocalEnvironment(tmp_path) + with pytest.raises(FileNotFoundError, match='directory'): + await env.read_file('subdir') + + +async def test_local_read_file_bytes_not_found(tmp_path: Path): + """LocalEnvironment.read_file_bytes raises on missing file.""" + env = LocalEnvironment(tmp_path) + with pytest.raises(FileNotFoundError, match='not found'): + await env.read_file('nonexistent.txt') + + +async def test_local_grep_on_file(tmp_path: Path): + """LocalEnvironment.grep on a specific file path.""" + (tmp_path / 'target.py').write_text('found = True\nmissed = False\n') + env = LocalEnvironment(tmp_path) + result = await env.grep('found', path='target.py') + assert 'found' in result + assert 'missed' not in result + + +async def test_local_grep_with_glob_pattern_filters_by_extension(tmp_path: Path): + """LocalEnvironment.grep with glob filtering.""" + (tmp_path / 'a.py').write_text('match_here\n') + (tmp_path / 'b.txt').write_text('match_here\n') + env = LocalEnvironment(tmp_path) + result = await env.grep('match_here', glob_pattern='*.py') + assert 'a.py' in result + assert 'b.txt' not in result + + +async def test_local_grep_skips_binary_files_with_null_bytes(tmp_path: Path): + """LocalEnvironment.grep skips files with null bytes.""" + (tmp_path / 'binary.bin').write_bytes(b'\x00binary content') + (tmp_path / 'text.txt').write_text('searchable\n') + env = LocalEnvironment(tmp_path) + result = await env.grep('searchable') + assert 'text.txt' in result + assert 'binary' not in result + + +async def test_local_grep_skips_hidden_files_in_hidden_dirs(tmp_path: Path): + """LocalEnvironment.grep skips hidden files/dirs.""" + hidden_dir = tmp_path / '.hidden' + hidden_dir.mkdir() + (hidden_dir / 'secret.txt').write_text('findme\n') + (tmp_path / 'visible.txt').write_text('findme\n') + env = LocalEnvironment(tmp_path) + result = await env.grep('findme') + assert 'visible.txt' in result + assert '.hidden' not in result + + +# --- Base class run_python tests --- + + +async def test_base_run_python_success(): + """Base ExecutionEnvironment.run_python writes code to file and runs via shell.""" + from pydantic_harness.environments._base import Capability as EnvCapability, ExecutionEnvironment as BaseEnv + + class _ShellEnv(BaseEnv): + """Env that records write_file/shell calls and returns canned results.""" + + written: dict[str, str | bytes] = {} + + @property + def capabilities(self) -> frozenset[EnvCapability]: + return frozenset({'shell', 'write_file', 'run_python'}) # pragma: no cover + + async def write_file(self, path: str, content: str | bytes) -> None: + self.written[path] = content + + async def shell( + self, command: str, *, timeout: float | None = 120, env: dict[str, str] | None = None + ) -> ExecutionResult: + return ExecutionResult(output='hello world\n', exit_code=0) + + env = _ShellEnv() + result = await env.run_python('print("hello world")') + assert result == 'hello world\n' + assert env.written['/tmp/_pydantic_ai_code.py'] == 'print("hello world")' + + +async def test_base_run_python_error(): + """Base ExecutionEnvironment.run_python raises CodeRuntimeError on non-zero exit.""" + from pydantic_harness.environments._base import Capability as EnvCapability, ExecutionEnvironment as BaseEnv + from pydantic_harness.toolsets.code_execution._abstract import CodeRuntimeError + + class _ShellEnv(BaseEnv): + @property + def capabilities(self) -> frozenset[EnvCapability]: + return frozenset({'shell', 'write_file', 'run_python'}) # pragma: no cover + + async def write_file(self, path: str, content: str | bytes) -> None: + pass + + async def shell( + self, command: str, *, timeout: float | None = 120, env: dict[str, str] | None = None + ) -> ExecutionResult: + return ExecutionResult(output='Exception: fail\n', exit_code=1) + + env = _ShellEnv() + with pytest.raises(CodeRuntimeError, match='Exception: fail'): + await env.run_python('raise Exception("fail")') + + +# --- Local environment additional tests --- + + +async def test_local_execute_output_truncation(tmp_path: Path): + """LocalEnvironment.execute truncates long output.""" + script = tmp_path / 'big.py' + script.write_text("print('x' * 200000)") + env = LocalEnvironment(tmp_path) + result = await env.shell(f'python3 {script}') + assert result.truncated is True + assert len(result.output) == 100_000 + + +async def test_local_process_wait_no_timeout(tmp_path: Path): + """LocalEnvironmentProcess.wait without timeout.""" + env = LocalEnvironment(tmp_path) + proc = await env.create_process('true') + async with proc: + exit_code = await proc.wait() # no timeout + assert exit_code == 0 + + +# --- Memory environment additional tests --- + + +async def test_memory_normalize_leading_slash_in_constructor(): + """MemoryEnvironment normalizes paths with leading /.""" + env = MemoryEnvironment(files={'/abs/path.txt': 'content'}) + content = await env.read_file('abs/path.txt') + assert isinstance(content, str) + assert 'content' in content + + +async def test_memory_read_file_directory_error(): + """MemoryEnvironment.read_file raises on directory paths.""" + env = MemoryEnvironment(files={'dir/file.txt': 'content'}) + with pytest.raises(FileNotFoundError, match='directory'): + await env.read_file('dir') + + +async def test_memory_read_file_bytes_not_found_raises_error(): + """MemoryEnvironment.read_file raises on missing file.""" + env = MemoryEnvironment() + with pytest.raises(FileNotFoundError): + await env.read_file('missing.txt') + + +async def test_memory_ls_non_root_directory(): + """MemoryEnvironment.ls lists files in a subdirectory.""" + env = MemoryEnvironment(files={'sub/a.txt': 'a', 'sub/b.txt': 'b', 'other.txt': 'c'}) + entries = await env.ls('sub') + assert len(entries) == 2 + names = {e.name for e in entries} + assert names == {'a.txt', 'b.txt'} + + +async def test_memory_ls_with_subdirs(): + """MemoryEnvironment.ls shows directories in listing.""" + env = MemoryEnvironment(files={'dir/sub/file.txt': 'content'}) + entries = await env.ls('dir') + assert len(entries) == 1 + assert entries[0].name == 'sub' + assert entries[0].is_dir is True + + +async def test_memory_ls_skips_non_children(): + """MemoryEnvironment.ls skips files not under the directory.""" + env = MemoryEnvironment(files={'a/b.txt': 'x', 'c/d.txt': 'y'}) + entries = await env.ls('a') + assert len(entries) == 1 + assert entries[0].name == 'b.txt' + + +async def test_memory_grep_binary_skip(): + """MemoryEnvironment.grep skips binary files.""" + env = MemoryEnvironment(files={'binary.bin': b'\x00binary data', 'text.txt': 'findme'}) + result = await env.grep('findme') + assert 'text.txt' in result + assert 'binary' not in result + + +async def test_memory_grep_path_filter(): + """MemoryEnvironment.grep filters by exact file path.""" + env = MemoryEnvironment(files={'sub/target.py': 'match_here', 'other.py': 'match_here'}) + result = await env.grep('match_here', path='sub') + assert 'sub/target.py' in result + assert 'other.py' not in result + + +async def test_memory_glob_in_subdirectory_with_path_filter(): + """MemoryEnvironment.glob works with path parameter.""" + env = MemoryEnvironment(files={'src/a.py': 'a', 'src/b.txt': 'b', 'other.py': 'c'}) + matches = await env.glob('*.py', path='src') + assert 'src/a.py' in matches + assert 'other.py' not in matches + + +async def test_memory_normalize_absolute_path(): + """MemoryEnvironment._normalize strips leading /.""" + env = MemoryEnvironment(files={'path.txt': 'content'}) + normalized = env._normalize('/path.txt') + assert normalized == 'path.txt' + + +async def test_memory_read_file_that_is_also_directory_prefix(): + """MemoryEnvironment.read_file when path exists as both file and directory prefix.""" + env = MemoryEnvironment(files={'dir': 'I am a file', 'dir/child.txt': 'child content'}) + async with env: + content = await env.read_file('dir') + assert isinstance(content, str) + assert 'I am a file' in content + + +async def test_memory_read_image_stored_as_string(): + """MemoryEnvironment returns bytes for image files even when stored as a string.""" + env = MemoryEnvironment(files={'image.png': 'fake png data'}) + async with env: + result = await env.read_file('image.png') + assert isinstance(result, bytes) + assert result == b'fake png data' + + +# --- ExecutionEnvironmentToolset resolution tests --- + + +def test_resolve_edit_tool_explicit_strategy(): + """Passing edit_strategy to constructor overrides auto-detection.""" + env = MemoryEnvironment() + toolset = ExecutionEnvironmentToolset(env, edit_strategy='apply_patch') + strategy = toolset._resolve_edit_tool(env) + assert strategy == 'apply_patch' + + +def test_resolve_edit_tool_apply_patch_fallback(): + """When env has apply_patch but not replace_str, resolves to apply_patch.""" + from pydantic_harness.environments._base import Capability as EnvCapability, ExecutionEnvironment as BaseEnv + + class _ApplyPatchEnv(BaseEnv): + @property + def capabilities(self) -> frozenset[EnvCapability]: + return frozenset({'apply_patch'}) + + toolset = ExecutionEnvironmentToolset(_ApplyPatchEnv()) + strategy = toolset._resolve_edit_tool(_ApplyPatchEnv()) + assert strategy == 'apply_patch' + + +def test_resolve_edit_tool_neither(): + """When env has neither replace_str nor apply_patch, returns None.""" + from pydantic_harness.environments._base import Capability as EnvCapability, ExecutionEnvironment as BaseEnv + + class _NoEditEnv(BaseEnv): + @property + def capabilities(self) -> frozenset[EnvCapability]: + return frozenset({'ls'}) + + toolset = ExecutionEnvironmentToolset(_NoEditEnv()) + strategy = toolset._resolve_edit_tool(_NoEditEnv()) + assert strategy is None + + +def test_resolve_capabilities_with_run_code_with_functions(): + """Env with run_python_with_functions maps to run_code_with_functions capability.""" + from pydantic_harness.environments._base import Capability as EnvCapability, ExecutionEnvironment as BaseEnv + + class _FunctionsEnv(BaseEnv): + @property + def capabilities(self) -> frozenset[EnvCapability]: + return frozenset({'run_python_with_functions'}) + + toolset = ExecutionEnvironmentToolset( + _FunctionsEnv(), + exclude=frozenset(), # don't exclude run_code + ) + caps = toolset._resolve_capabilities(_FunctionsEnv()) + assert 'run_code_with_functions' in caps + + +# --- Toolset ls formatting tests --- + + +async def test_toolset_ls_formats_dirs(): + """Toolset ls formats directory entries with trailing /.""" + env = MemoryEnvironment(files={'sub/a.txt': 'hello'}) + toolset = ExecutionEnvironmentToolset(env) + ctx = build_run_context() + tools = await toolset.get_tools(ctx) + async with env: + result = await toolset.call_tool('ls', {'path': '.'}, ctx, tools['ls']) + assert 'sub/' in str(result) + + +async def test_toolset_ls_error_handling(): + """Toolset ls returns error string when environment raises.""" + from pydantic_harness.environments._base import Capability as EnvCapability, ExecutionEnvironment as BaseEnv + + class _ErrorLsEnv(BaseEnv): + @property + def capabilities(self) -> frozenset[EnvCapability]: + return frozenset({'ls'}) + + async def ls(self, path: str = '.') -> list[FileInfo]: + raise NotADirectoryError(f'Not a directory: {path}') + + env = _ErrorLsEnv() + toolset = ExecutionEnvironmentToolset(env) + ctx = build_run_context() + tools = await toolset.get_tools(ctx) + result = await toolset.call_tool('ls', {'path': '/bad'}, ctx, tools['ls']) + assert 'Error:' in str(result) + + +async def test_toolset_ls_formats_files_without_size(): + """Toolset ls formats file entries without size (just the name).""" + from pydantic_harness.environments._base import Capability as EnvCapability, ExecutionEnvironment as BaseEnv + + class _NoSizeEnv(BaseEnv): + @property + def capabilities(self) -> frozenset[EnvCapability]: + return frozenset({'ls'}) + + async def ls(self, path: str = '.') -> list[FileInfo]: + return [FileInfo(name='readme.txt', path='readme.txt', is_dir=False, size=None)] + + env = _NoSizeEnv() + toolset = ExecutionEnvironmentToolset(env) + ctx = build_run_context() + tools = await toolset.get_tools(ctx) + result = await toolset.call_tool('ls', {'path': '.'}, ctx, tools['ls']) + assert str(result) == 'readme.txt' + + +async def test_toolset_ls_empty_directory(): + """Toolset ls returns 'Empty directory.' for empty listings.""" + from pydantic_harness.environments._base import Capability as EnvCapability, ExecutionEnvironment as BaseEnv + + class _EmptyLsEnv(BaseEnv): + @property + def capabilities(self) -> frozenset[EnvCapability]: + return frozenset({'ls'}) + + async def ls(self, path: str = '.') -> list[FileInfo]: + return [] + + env = _EmptyLsEnv() + toolset = ExecutionEnvironmentToolset(env) + ctx = build_run_context() + tools = await toolset.get_tools(ctx) + result = await toolset.call_tool('ls', {'path': '.'}, ctx, tools['ls']) + assert str(result) == 'Empty directory.' + + +# --- Lazy import test --- + + +def test_lazy_import_code_execution_toolset(): + """CodeExecutionToolset is importable via pydantic_harness.toolsets.""" + from pydantic_harness.toolsets import CodeExecutionToolset + + assert CodeExecutionToolset is not None diff --git a/tests/test_python_signature.py b/tests/test_python_signature.py new file mode 100644 index 0000000..15f5958 --- /dev/null +++ b/tests/test_python_signature.py @@ -0,0 +1,1074 @@ +"""Tests for Python signature generation and deduplication.""" + +from __future__ import annotations + +from typing import Optional, Union + +import pytest +from inline_snapshot import snapshot +from pydantic import BaseModel, RootModel + +from pydantic_harness._python_signature import ( + FunctionParam, + FunctionSignature, + GenericTypeExpr, + TypeFieldSignature, + TypeSignature, + UnionTypeExpr, + _annotation_to_type_expr, # pyright: ignore[reportPrivateUsage] + _to_pascal_case, # pyright: ignore[reportPrivateUsage] + collect_unique_referenced_types, + dedup_referenced_types, + function_to_signature, + render_type_expr, + schema_to_signature, +) +from pydantic_ai._run_context import RunContext + +pytestmark = pytest.mark.anyio + + +def test_dedup_referenced_types_substring_names(): + """Renaming 'User' must not corrupt 'UserMeta' in the same signature.""" + user1 = TypeSignature( + name='User', + fields={ + 'name': TypeFieldSignature(name='name', type='str', required=True, description=None), + }, + ) + user2 = TypeSignature( + name='User', + fields={ + 'id': TypeFieldSignature(name='id', type='int', required=True, description=None), + }, + ) + user_meta = TypeSignature( + name='UserMeta', + fields={ + 'role': TypeFieldSignature(name='role', type='str', required=True, description=None), + }, + ) + + sig1 = FunctionSignature( + name='tool_a', + params={'user': FunctionParam(name='user', type=user1, default=None)}, + return_type='Any', + referenced_types=[user1], + ) + sig2 = FunctionSignature( + name='tool_b', + params={ + 'user': FunctionParam(name='user', type=user2, default=None), + 'meta': FunctionParam(name='meta', type=user_meta, default=None), + }, + return_type=user_meta, + referenced_types=[user2, user_meta], + ) + dedup_referenced_types([sig1, sig2]) + + # User in sig2 was renamed to tool_b_User + assert user2.name == 'tool_b_User' + # UserMeta must be untouched + assert user_meta.name == 'UserMeta' + # Params render correctly + assert str(sig2.params['user']) == 'user: tool_b_User' + assert str(sig2.params['meta']) == 'meta: UserMeta' + assert render_type_expr(sig2.return_type) == 'UserMeta' + + +def test_dedup_identical_types_unified(): + """Identical TypeSignatures are unified to the same object instance.""" + user1 = TypeSignature( + name='User', + fields={ + 'name': TypeFieldSignature(name='name', type='str', required=True, description=None), + }, + ) + user2 = TypeSignature( + name='User', + fields={ + 'name': TypeFieldSignature(name='name', type='str', required=True, description=None), + }, + ) + + sig1 = FunctionSignature( + name='tool_a', + params={'user': FunctionParam(name='user', type=user1, default=None)}, + return_type='Any', + referenced_types=[user1], + ) + sig2 = FunctionSignature( + name='tool_b', + params={'user': FunctionParam(name='user', type=user2, default=None)}, + return_type='Any', + referenced_types=[user2], + ) + dedup_referenced_types([sig1, sig2]) + + # Both sigs keep the type, but unified to the same canonical instance + assert len(sig2.referenced_types) == 1 + assert sig2.referenced_types[0] is user1 + # sig2's param should now point to the canonical (sig1's) TypeSignature + assert sig2.params['user'].type is user1 + + # collect_unique_referenced_types emits the definition only once + defs = collect_unique_referenced_types([sig1, sig2]) + assert len(defs) == 1 + + +def test_dedup_mixed_identical_and_conflicting_from_schemas(): + """Two identical User $defs are unified; a third different User is renamed. + + Tests the full pipeline: JSON schema -> schema_to_signature -> dedup -> render. + """ + # tool_a and tool_b both have a User $def with {name: str} + user_v1_def = { + 'type': 'object', + 'properties': {'name': {'type': 'string'}}, + 'required': ['name'], + } + # tool_c has a User $def with {id: int} -- same name, different structure + user_v2_def = { + 'type': 'object', + 'properties': {'id': {'type': 'integer'}}, + 'required': ['id'], + } + + sig1 = schema_to_signature( + 'tool_a', + { + 'type': 'object', + 'properties': {'user': {'$ref': '#/$defs/User'}}, + 'required': ['user'], + '$defs': {'User': user_v1_def}, + }, + ) + sig2 = schema_to_signature( + 'tool_b', + { + 'type': 'object', + 'properties': {'user': {'$ref': '#/$defs/User'}}, + 'required': ['user'], + '$defs': {'User': user_v1_def}, + }, + ) + sig3 = schema_to_signature( + 'tool_c', + { + 'type': 'object', + 'properties': {'user': {'$ref': '#/$defs/User'}}, + 'required': ['user'], + '$defs': {'User': user_v2_def}, + }, + ) + + dedup_referenced_types([sig1, sig2, sig3]) + + # sig1 and sig2 share the same canonical User instance + assert len(sig1.referenced_types) == 1 + assert len(sig2.referenced_types) == 1 + assert sig2.referenced_types[0] is sig1.referenced_types[0] + assert sig2.params['user'].type is sig1.referenced_types[0] + + # sig3's User was renamed to tool_c_User + assert sig3.referenced_types[0].name == 'tool_c_User' + assert str(sig3.params['user']) == 'user: tool_c_User' + + # collect_unique_referenced_types returns exactly 2 unique types + unique_types = collect_unique_referenced_types([sig1, sig2, sig3]) + assert [str(t) for t in unique_types] == snapshot( + [ + """\ +class User(TypedDict): + name: str""", + """\ +class tool_c_User(TypedDict): + id: int""", + ] + ) + + +def test_dedup_composite_type_expr_rename_propagates(): + """Renaming propagates through GenericTypeExpr references like list[User].""" + user1 = TypeSignature( + name='User', + fields={ + 'name': TypeFieldSignature(name='name', type='str', required=True, description=None), + }, + ) + user2 = TypeSignature( + name='User', + fields={ + 'id': TypeFieldSignature(name='id', type='int', required=True, description=None), + }, + ) + + sig1 = FunctionSignature( + name='tool_a', + params={'user': FunctionParam(name='user', type=user1, default=None)}, + return_type='Any', + referenced_types=[user1], + ) + # sig2 references user2 via list[User] + sig2 = FunctionSignature( + name='tool_b', + params={'users': FunctionParam(name='users', type=GenericTypeExpr(base='list', args=[user2]), default=None)}, + return_type='Any', + referenced_types=[user2], + ) + dedup_referenced_types([sig1, sig2]) + + # user2 was renamed in place + assert user2.name == 'tool_b_User' + # The list[User] now renders as list[tool_b_User] + assert str(sig2.params['users']) == 'users: list[tool_b_User]' + + +def test_render_type_signature(): + """TypeSignature renders a valid TypedDict class definition.""" + ts = TypeSignature( + name='User', + fields={ + 'name': TypeFieldSignature(name='name', type='str', required=True, description=None), + 'age': TypeFieldSignature(name='age', type='int', required=False, description='The age'), + }, + ) + assert str(ts) == snapshot("""\ +class User(TypedDict): + name: str + age: NotRequired[int] + \"\"\"The age\"\"\"\ +""") + + +def test_render_type_signature_empty(): + """Empty TypeSignature renders with pass.""" + ts = TypeSignature(name='Empty') + assert str(ts) == 'class Empty(TypedDict):\n pass' + + +def test_render_generic_type_expr(): + """GenericTypeExpr renders correctly.""" + user = TypeSignature(name='User') + expr = GenericTypeExpr(base='list', args=[user]) + assert str(expr) == 'list[User]' + + dict_expr = GenericTypeExpr(base='dict', args=['str', GenericTypeExpr(base='list', args=[user])]) + assert str(dict_expr) == 'dict[str, list[User]]' + + +def test_render_union_type_expr(): + """UnionTypeExpr renders correctly.""" + user = TypeSignature(name='User') + expr = UnionTypeExpr(members=[user, 'None']) + assert str(expr) == 'User | None' + + +def test_render_function_param(): + """FunctionParam renders correct parameter strings.""" + p1 = FunctionParam(name='x', type='int', default=None) + assert str(p1) == 'x: int' + + p2 = FunctionParam(name='y', type='str', default="'hello'") + assert str(p2) == "y: str = 'hello'" + + user = TypeSignature(name='User') + p3 = FunctionParam(name='user', type=user, default=None) + assert str(p3) == 'user: User' + + +def test_structurally_equal(): + """TypeSignature.structurally_equal compares fields structurally.""" + ts1 = TypeSignature( + name='A', + fields={ + 'x': TypeFieldSignature(name='x', type='int', required=True, description='desc1'), + }, + ) + ts2 = TypeSignature( + name='B', + fields={ + 'x': TypeFieldSignature(name='x', type='int', required=True, description='different desc'), + }, + ) + # Same fields, different names and descriptions -- structurally equal + assert ts1.structurally_equal(ts2) + + ts3 = TypeSignature( + name='C', + fields={ + 'x': TypeFieldSignature(name='x', type='str', required=True, description=None), + }, + ) + # Different field type -- not equal + assert not ts1.structurally_equal(ts3) + + +def test_to_pascal_case_digit_prefix(): + """PascalCase of a name starting with digits gets a leading underscore.""" + assert _to_pascal_case('123_tool') == '_123Tool' + + +def test_to_pascal_case_edge_cases(): + """Edge cases: empty string, all-digits, hyphenated.""" + assert _to_pascal_case('') == '' + assert _to_pascal_case('42') == '_42' + assert _to_pascal_case('my-tool-name') == 'MyToolName' + + + +# NOTE: Tests for FunctionToolDefinition.python_signature removed — +# FunctionToolDefinition is from PR #4755, not yet in pydantic-ai-slim. +# TODO: Re-add when PR #4755 lands. + + +# ============================================================================= +# Function signature edge cases +# ============================================================================= + + +def test_function_signature_special_params(): + """RunContext skipped, unannotated -> Any.""" + + def with_ctx(ctx: RunContext[None], x: int) -> int: + return x # pragma: no cover + + assert str(function_to_signature(with_ctx, name='with_ctx')) == snapshot("""\ +async def with_ctx(*, x: int) -> int: + ...\ +""") + + def no_annot(x): # pyright: ignore[reportUnknownParameterType,reportMissingParameterType] + return x # pragma: no cover # pyright: ignore[reportUnknownVariableType] + + assert str( + function_to_signature(no_annot, name='no_annot') # pyright: ignore[reportUnknownArgumentType] + ) == snapshot("""\ +async def no_annot(*, x: Any) -> Any: + ...\ +""") + + +class _UserInfo(BaseModel): + name: str + + +class _IntList(RootModel[list[int]]): + pass + + +class _TreeNode(BaseModel): + value: int + children: list[_TreeNode] = [] + + +def test_function_signature_union_and_model_types(): + """Unions, optionals, and model types render correct signatures.""" + + def complex_func( + a: Union[int, str], # noqa: UP007 -- testing Union[] code path + b: int | str, + c: Optional[int] = None, # noqa: UP045 -- testing Optional[] code path + d: _UserInfo | None = None, + ) -> _UserInfo: ... # pragma: no cover + + sig = function_to_signature(complex_func, name='complex_func') + assert str(sig) == snapshot("""\ +async def complex_func(*, a: int | str, b: int | str, c: int | None = None, d: _UserInfo | None = None) -> _UserInfo: + ...\ +""") + assert [str(t) for t in sig.referenced_types] == snapshot( + [ + """\ +class _UserInfo(TypedDict): + name: str\ +""" + ] + ) + + +def test_type_signature_docstring_and_structural_equality(): + """Docstring rendering and structural equality with different required.""" + ts = TypeSignature(name='Documented', docstring='A documented empty type') + assert str(ts) == snapshot('''\ +class Documented(TypedDict): + """A documented empty type"""\ +''') + + ts_a = TypeSignature( + name='A', + fields={'x': TypeFieldSignature(name='x', type='int', required=True, description=None)}, + ) + ts_b = TypeSignature( + name='B', + fields={'x': TypeFieldSignature(name='x', type='int', required=False, description=None)}, + ) + assert not ts_a.structurally_equal(ts_b) + + +# ============================================================================= +# Schema signature edge cases +# ============================================================================= + + +def test_schema_signature_const_enum(): + """const and enum paths in schema_to_type_expr produce Literal types.""" + # const value + sig_const = schema_to_signature( + 'tool_const', + { + 'type': 'object', + 'properties': {'mode': {'const': 'fast'}}, + 'required': ['mode'], + }, + ) + assert str(sig_const) == snapshot("""\ +async def tool_const(*, mode: Literal['fast']) -> Any: + ...\ +""") + + # enum values + sig_enum = schema_to_signature( + 'tool_enum', + { + 'type': 'object', + 'properties': {'color': {'enum': ['red', 'green', 'blue']}}, + 'required': ['color'], + }, + ) + assert str(sig_enum) == snapshot("""\ +async def tool_enum(*, color: Literal['red', 'green', 'blue']) -> Any: + ...\ +""") + + +def test_collect_unique_referenced_types_empty(): + """Empty input returns empty list.""" + assert collect_unique_referenced_types([]) == [] + + sig = FunctionSignature(name='no_refs', params={}, return_type='Any', referenced_types=[]) + assert collect_unique_referenced_types([sig]) == [] + + +def test_schema_signature_union_ref_allof(): + """oneOf, allOf, $ref variants produce correct signatures.""" + sig_oneof = schema_to_signature( + 'my_tool', + { + 'type': 'object', + 'properties': {'value': {'oneOf': [{'type': 'string'}, {'type': 'integer'}]}}, + 'required': ['value'], + }, + ) + assert str(sig_oneof) == snapshot("""\ +async def my_tool(*, value: str | int) -> Any: + ...\ +""") + + sig_allof_single = schema_to_signature( + 'tool2', + { + 'type': 'object', + 'properties': {'x': {'allOf': [{'type': 'string'}]}}, + 'required': ['x'], + }, + ) + assert str(sig_allof_single) == snapshot("""\ +async def tool2(*, x: str) -> Any: + ...\ +""") + + sig_allof_multi = schema_to_signature( + 'tool3', + { + 'type': 'object', + 'properties': {'x': {'allOf': [{'type': 'string'}, {'type': 'integer'}]}}, + 'required': ['x'], + }, + ) + assert str(sig_allof_multi) == snapshot("""\ +async def tool3(*, x: Any) -> Any: + ...\ +""") + + sig_ref = schema_to_signature( + 'tool4', + { + 'type': 'object', + 'properties': {'user': {'$ref': '#/$defs/User'}}, + 'required': ['user'], + '$defs': {'User': {'type': 'object', 'properties': {'name': {'type': 'string'}}, 'required': ['name']}}, + }, + ) + assert str(sig_ref) == snapshot("""\ +async def tool4(*, user: User) -> Any: + ...\ +""") + assert [str(t) for t in sig_ref.referenced_types] == snapshot( + [ + """\ +class User(TypedDict): + name: str\ +""" + ] + ) + + sig_ref_nonobj = schema_to_signature( + 'tool5', + { + 'type': 'object', + 'properties': {'x': {'$ref': '#/$defs/StringAlias'}}, + 'required': ['x'], + '$defs': {'StringAlias': {'type': 'string'}}, + }, + ) + assert str(sig_ref_nonobj) == snapshot("""\ +async def tool5(*, x: StringAlias) -> Any: + ...\ +""") + + +def test_schema_signature_array_object_typelist(): + """Arrays, objects, additionalProperties, and type lists.""" + # Tuple array + assert str( + schema_to_signature( + 't1', + { + 'type': 'object', + 'properties': {'coords': {'type': 'array', 'items': [{'type': 'number'}, {'type': 'number'}]}}, + 'required': ['coords'], + }, + ) + ) == snapshot("""\ +async def t1(*, coords: tuple[float, float]) -> Any: + ...\ +""") + + # Empty array + assert str( + schema_to_signature( + 't2', + { + 'type': 'object', + 'properties': {'data': {'type': 'array'}}, + 'required': ['data'], + }, + ) + ) == snapshot("""\ +async def t2(*, data: list[Any]) -> Any: + ...\ +""") + + # additionalProperties: true + assert str( + schema_to_signature( + 't3', + { + 'type': 'object', + 'properties': {'meta': {'type': 'object', 'additionalProperties': True}}, + 'required': ['meta'], + }, + ) + ) == snapshot("""\ +async def t3(*, meta: dict[str, Any]) -> Any: + ...\ +""") + + # Typed additionalProperties + assert str( + schema_to_signature( + 't4', + { + 'type': 'object', + 'properties': {'tags': {'type': 'object', 'additionalProperties': {'type': 'string'}}}, + 'required': ['tags'], + }, + ) + ) == snapshot("""\ +async def t4(*, tags: dict[str, str]) -> Any: + ...\ +""") + + # Type list ['string', 'null'] + assert str( + schema_to_signature( + 't5', + { + 'type': 'object', + 'properties': {'name': {'type': ['string', 'null']}}, + 'required': ['name'], + }, + ) + ) == snapshot("""\ +async def t5(*, name: str | None) -> Any: + ...\ +""") + + # Type list multi + assert str( + schema_to_signature( + 't6', + { + 'type': 'object', + 'properties': {'value': {'type': ['string', 'integer', 'boolean']}}, + 'required': ['value'], + }, + ) + ) == snapshot("""\ +async def t6(*, value: str | int | bool) -> Any: + ...\ +""") + + # Object type list with null + sig = schema_to_signature( + 't7', + { + 'type': 'object', + 'properties': { + 'config': { + 'type': ['object', 'null'], + 'properties': {'enabled': {'type': 'boolean'}}, + 'required': ['enabled'], + }, + }, + 'required': ['config'], + }, + ) + assert str(sig) == snapshot("""\ +async def t7(*, config: T7Config | None) -> Any: + ...\ +""") + assert [str(t) for t in sig.referenced_types] == snapshot( + [ + """\ +class T7Config(TypedDict): + enabled: bool\ +""" + ] + ) + + +def test_schema_signature_optional_params_and_return(): + """Optional params, return schema edge cases, anyOf dedup.""" + # Optional already nullable + assert str( + schema_to_signature( + 't1', + { + 'type': 'object', + 'properties': {'x': {'type': ['string', 'null']}}, + }, + ) + ) == snapshot("""\ +async def t1(*, x: str | None = None) -> Any: + ...\ +""") + + # Optional not nullable -> adds | None + assert str( + schema_to_signature( + 't2', + { + 'type': 'object', + 'properties': {'x': {'type': 'string'}}, + }, + ) + ) == snapshot("""\ +async def t2(*, x: str | None = None) -> Any: + ...\ +""") + + # Unresolvable return_schema -> JSON blob in description + sig3 = schema_to_signature( + 't3', + {'type': 'object', 'properties': {'x': {'type': 'string'}}, 'required': ['x']}, + description='A tool', + return_schema={}, + ) + assert str(sig3) == snapshot('''\ +async def t3(*, x: str) -> Any: + """ + A tool + + Return schema: + {} + """ + ...\ +''') + + # Return schema with $defs + sig4 = schema_to_signature( + 't4', + {'type': 'object', 'properties': {'x': {'type': 'string'}}, 'required': ['x']}, + return_schema={ + '$ref': '#/$defs/Result', + '$defs': {'Result': {'type': 'object', 'properties': {'v': {'type': 'integer'}}, 'required': ['v']}}, + }, + ) + assert str(sig4) == snapshot("""\ +async def t4(*, x: str) -> Result: + ...\ +""") + assert [str(t) for t in sig4.referenced_types] == snapshot( + [ + """\ +class Result(TypedDict): + v: int\ +""" + ] + ) + + # anyOf with duplicates -> deduplicated + assert str( + schema_to_signature( + 't5', + { + 'type': 'object', + 'properties': {'x': {'anyOf': [{'type': 'string'}, {'type': 'string'}, {'type': 'null'}]}}, + 'required': ['x'], + }, + ) + ) == snapshot("""\ +async def t5(*, x: str | None) -> Any: + ...\ +""") + + +# ============================================================================= +# Additional coverage tests +# ============================================================================= + + + +# NOTE: test_function_tool_definition_eq_non_tool removed — depends on FunctionToolDefinition (PR #4755) + + +def test_get_type_name_repr_fallback(): + """Types without __name__ use repr fallback, NoneType returns 'None'.""" + from pydantic_harness._python_signature import _get_type_name # pyright: ignore[reportPrivateUsage] + + # NoneType returns 'None' + assert _get_type_name(type(None)) == 'None' + + # Literal type args (e.g. the string 'a') have no __name__ -> repr fallback + assert _get_type_name('some_value') == "'some_value'" + + +def test_function_signature_literal_annotation(): + """Literal type annotations exercise the repr fallback in _get_type_name.""" + import typing + + ns: dict[str, object] = {'typing': typing} + exec("def func(x: typing.Literal['a', 'b']) -> None: ...", ns) + sig = function_to_signature(ns['func'], name='func') # pyright: ignore[reportArgumentType] + assert "Literal['a', 'b']" in str(sig) + + +def test_annotation_to_type_expr_bare_generic(): + """Bare generic (origin but no type args) returns the origin's type name.""" + import typing + + # typing.List has __origin__=list but no __args__ + result = _annotation_to_type_expr(typing.List, {}) # noqa: UP006 + assert result == 'list' + + +def test_function_signature_nameerror_fallback(): + """Functions with unresolvable forward refs fall back to empty type hints.""" + ns: dict[str, object] = {} + exec( + "def func_with_fwd_ref(x: 'NonexistentType') -> None: ...", + ns, + ) + func = ns['func_with_fwd_ref'] + sig = function_to_signature(func, name='func_fwd') # pyright: ignore[reportArgumentType] + # Should not raise -- falls back to empty hints, x becomes Any + assert 'x' in sig.params + + +def test_schema_allows_null_anyof(): + """_schema_allows_null detects null in anyOf.""" + from pydantic_harness._python_signature import _schema_allows_null # pyright: ignore[reportPrivateUsage] + + assert _schema_allows_null({'anyOf': [{'type': 'string'}, {'type': 'null'}]}) is True + assert _schema_allows_null({'anyOf': [{'type': 'string'}]}) is False + + +def test_schema_bare_object(): + """Bare object type (no properties, no additionalProperties) renders as dict.""" + sig = schema_to_signature( + 't_bare', + { + 'type': 'object', + 'properties': {'data': {'type': 'object'}}, + 'required': ['data'], + }, + ) + assert 'dict[str, Any]' in str(sig) + + +def test_schema_object_type_list_no_null(): + """Object in type list without null renders without union.""" + sig = schema_to_signature( + 't_obj_list', + { + 'type': 'object', + 'properties': { + 'config': { + 'type': ['object'], + 'properties': {'x': {'type': 'string'}}, + 'required': ['x'], + }, + }, + 'required': ['config'], + }, + ) + rendered = str(sig) + assert 'None' not in rendered + assert 'TObjListConfig' in rendered + + +def test_schema_single_type_after_filtering(): + """Single-element type list renders as plain type.""" + sig = schema_to_signature( + 't_single', + { + 'type': 'object', + 'properties': {'x': {'type': ['string']}}, + 'required': ['x'], + }, + ) + assert 'str' in str(sig) + # Should not be a union + assert '|' not in str(sig) + + +def test_schema_single_anyof_member(): + """Single-member anyOf returns the type directly, not a union.""" + sig = schema_to_signature( + 't_anyof_single', + { + 'type': 'object', + 'properties': {'x': {'anyOf': [{'type': 'string'}]}}, + 'required': ['x'], + }, + ) + rendered = str(sig) + assert 'str' in rendered + assert '|' not in rendered + + +def test_schema_defs_already_processed(): + """Second call with same $defs name finds it already processed.""" + sig = schema_to_signature( + 'tool_shared', + { + 'type': 'object', + 'properties': { + 'a': {'$ref': '#/$defs/Shared'}, + 'b': {'$ref': '#/$defs/Shared'}, + }, + 'required': ['a', 'b'], + '$defs': { + 'Shared': { + 'type': 'object', + 'properties': {'v': {'type': 'integer'}}, + 'required': ['v'], + } + }, + }, + ) + # Both params reference the same TypeSignature + assert render_type_expr(sig.params['a'].type) == 'Shared' + assert render_type_expr(sig.params['b'].type) == 'Shared' + # Only one referenced type + assert len(sig.referenced_types) == 1 + + +def test_schema_return_type_dedup(): + """Param and return schemas sharing a $defs type produce one definition.""" + shared_def = { + 'type': 'object', + 'properties': {'id': {'type': 'integer'}}, + 'required': ['id'], + } + sig = schema_to_signature( + 'tool_dedup', + { + 'type': 'object', + 'properties': {'item': {'$ref': '#/$defs/Item'}}, + 'required': ['item'], + '$defs': {'Item': shared_def}, + }, + return_schema={ + '$ref': '#/$defs/Item', + '$defs': {'Item': shared_def}, + }, + ) + # Both param and return reference Item + assert render_type_expr(sig.params['item'].type) == 'Item' + assert render_type_expr(sig.return_type) == 'Item' + + +def test_schema_object_type_name_collision(): + """Two properties generating the same path-based type name -- second is reused.""" + sig = schema_to_signature( + 'tool_collision', + { + 'type': 'object', + 'properties': { + 'data': { + 'type': 'object', + 'properties': {'x': {'type': 'string'}}, + 'required': ['x'], + }, + }, + 'required': ['data'], + }, + ) + # Should have a TypedDict for the nested object + assert len(sig.referenced_types) == 1 + assert sig.referenced_types[0].name == 'ToolCollisionData' + + +def test_function_signature_root_model(): + """RootModel wrapping a non-object type renders as the type name without a TypedDict.""" + + def func_with_root_model(x: _IntList) -> None: ... # pragma: no cover + + sig = function_to_signature(func_with_root_model, name='func_with_root_model') + # RootModel produces a non-object schema, so it's referenced by name but not as a TypedDict + assert str(sig) == snapshot("""\ +async def func_with_root_model(*, x: _IntList) -> None: + ...\ +""") + assert sig.referenced_types == [] + + +def test_function_signature_recursive_model(): + """Recursive BaseModel with top-level $ref in schema produces correct TypedDict.""" + + def func_with_tree(x: _TreeNode) -> None: ... # pragma: no cover + + sig = function_to_signature(func_with_tree, name='func_with_tree') + assert str(sig) == snapshot("""\ +async def func_with_tree(*, x: _TreeNode) -> None: + ...\ +""") + assert len(sig.referenced_types) == 1 + assert sig.referenced_types[0].name == '_TreeNode' + assert 'value' in sig.referenced_types[0].fields + assert 'children' in sig.referenced_types[0].fields + + +def test_schema_cross_referencing_defs(): + """$ref in a def property lazily resolves another def not yet processed.""" + sig = schema_to_signature( + 'tool', + { + 'type': 'object', + 'properties': {'item': {'$ref': '#/$defs/Container'}}, + 'required': ['item'], + '$defs': { + 'Container': { + 'type': 'object', + 'properties': {'inner': {'$ref': '#/$defs/Inner'}}, + 'required': ['inner'], + }, + 'Inner': { + 'type': 'object', + 'properties': {'value': {'type': 'string'}}, + 'required': ['value'], + }, + }, + }, + ) + assert str(sig) == snapshot("""\ +async def tool(*, item: Container) -> Any: + ...\ +""") + # Both Container and Inner should be resolved as TypeSignatures + type_names = {t.name for t in sig.referenced_types} + assert type_names == {'Container', 'Inner'} + # Container's 'inner' field references the Inner TypeSignature + container = next(t for t in sig.referenced_types if t.name == 'Container') + assert render_type_expr(container.fields['inner'].type) == 'Inner' + + +def test_collect_referenced_types_skips_already_registered(): + """When a $def name is already in referenced_types, _collect_referenced_types skips it.""" + from pydantic_harness._python_signature import _collect_referenced_types # pyright: ignore[reportPrivateUsage] + + class _Address(BaseModel): + street: str + + class _Person(BaseModel): + name: str + home: _Address + + # Processing Address first registers it + referenced_types: dict[str, TypeSignature] = {} + _collect_referenced_types(_Address, referenced_types, 'func', 'a') + assert '_Address' in referenced_types + + # Now collect Person -- its $defs include _Address which should be skipped + _collect_referenced_types(_Person, referenced_types, 'func', 'b') + assert '_Person' in referenced_types + + +def test_schema_inline_object_reuses_existing_typename(): + """Inline object property reuses a $defs type when path-based names collide.""" + sig = schema_to_signature( + 'tool', + { + 'type': 'object', + 'properties': { + 'data': { + 'type': 'object', + 'properties': {'x': {'type': 'string'}}, + 'required': ['x'], + }, + }, + 'required': ['data'], + '$defs': { + # This $def's name matches the path-based typename for property 'data' + # _path_to_typename('tool', 'data') == 'ToolData' + 'ToolData': { + 'type': 'object', + 'properties': {'y': {'type': 'integer'}}, + 'required': ['y'], + }, + }, + }, + ) + # The $defs ToolData was registered first; the inline 'data' property reuses it + assert render_type_expr(sig.params['data'].type) == 'ToolData' + + +def test_schema_additional_properties_false(): + """additionalProperties: false falls through to dict[str, Any].""" + sig = schema_to_signature( + 't_ap_false', + { + 'type': 'object', + 'properties': { + 'meta': {'type': 'object', 'additionalProperties': False}, + }, + 'required': ['meta'], + }, + ) + assert 'dict[str, Any]' in str(sig) + + +def test_schema_empty_type_list(): + """Empty type list produces Any.""" + sig = schema_to_signature( + 't_empty', + { + 'type': 'object', + 'properties': {'x': {'type': []}}, + 'required': ['x'], + }, + ) + assert 'Any' in str(sig) diff --git a/uv.lock b/uv.lock index 8a1e193..7059ac0 100644 --- a/uv.lock +++ b/uv.lock @@ -25,6 +25,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/38/0e/27be9fdef66e72d64c0cdc3cc2823101b80585f8119b5c112c2e8f5f7dab/anyio-4.12.1-py3-none-any.whl", hash = "sha256:d405828884fc140aa80a3c667b8beed277f1dfedec42ba031bd6ac3db606ab6c", size = 113592, upload-time = "2026-01-06T11:45:19.497Z" }, ] +[[package]] +name = "asttokens" +version = "3.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/be/a5/8e3f9b6771b0b408517c82d97aed8f2036509bc247d46114925e32fe33f0/asttokens-3.0.1.tar.gz", hash = "sha256:71a4ee5de0bde6a31d64f6b13f2293ac190344478f081c3d1bccfcf5eacb0cb7", size = 62308, upload-time = "2025-11-15T16:43:48.578Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/39/e7eaf1799466a4aef85b6a4fe7bd175ad2b1c6345066aa33f1f58d4b18d0/asttokens-3.0.1-py3-none-any.whl", hash = "sha256:15a3ebc0f43c2d0a50eeafea25e19046c68398e487b9f1f5b517f7c0f40f976a", size = 27047, upload-time = "2025-11-15T16:43:16.109Z" }, +] + [[package]] name = "certifi" version = "2026.2.25" @@ -156,6 +165,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9e/ee/a4cf96b8ce1e566ed238f0659ac2d3f007ed1d14b181bcb684e19561a69a/coverage-7.13.5-py3-none-any.whl", hash = "sha256:34b02417cf070e173989b3db962f7ed56d2f644307b2cf9d5a0f258e13084a61", size = 211346, upload-time = "2026-03-17T10:33:15.691Z" }, ] +[[package]] +name = "dirty-equals" +version = "0.11" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/30/1d/c5913ac9d6615515a00f4bdc71356d302437cb74ff2e9aaccd3c14493b78/dirty_equals-0.11.tar.gz", hash = "sha256:f4ac74ee88f2d11e2fa0f65eb30ee4f07105c5f86f4dc92b09eb1138775027c3", size = 128067, upload-time = "2025-11-17T01:51:24.451Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bb/8d/dbff05239043271dbeace563a7686212a3dd517864a35623fe4d4a64ca19/dirty_equals-0.11-py3-none-any.whl", hash = "sha256:b1d7093273fc2f9be12f443a8ead954ef6daaf6746fd42ef3a5616433ee85286", size = 28051, upload-time = "2025-11-17T01:51:22.849Z" }, +] + [[package]] name = "exceptiongroup" version = "1.3.1" @@ -177,6 +195,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ab/84/02fc1827e8cdded4aa65baef11296a9bbe595c474f0d6d758af082d849fd/execnet-2.1.2-py3-none-any.whl", hash = "sha256:67fba928dd5a544b783f6056f449e5e3931a5c378b128bc18501f7ea79e296ec", size = 40708, upload-time = "2025-11-12T09:56:36.333Z" }, ] +[[package]] +name = "executing" +version = "2.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/cc/28/c14e053b6762b1044f34a13aab6859bbf40456d37d23aa286ac24cfd9a5d/executing-2.2.1.tar.gz", hash = "sha256:3632cc370565f6648cc328b32435bd120a1e4ebb20c77e3fdde9a13cd1e533c4", size = 1129488, upload-time = "2025-09-01T09:48:10.866Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c1/ea/53f2148663b321f21b5a606bd5f191517cf40b7072c0497d3c92c4a13b1e/executing-2.2.1-py2.py3-none-any.whl", hash = "sha256:760643d3452b4d777d295bb167ccc74c64a81df23fb5e08eff250c425a4b2017", size = 28317, upload-time = "2025-09-01T09:48:08.5Z" }, +] + [[package]] name = "genai-prices" version = "0.0.56" @@ -194,6 +221,7 @@ wheels = [ name = "griffelib" version = "2.0.0" source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ad/06/eccbd311c9e2b3ca45dbc063b93134c57a1ccc7607c5e545264ad092c4a9/griffelib-2.0.0.tar.gz", hash = "sha256:e504d637a089f5cab9b5daf18f7645970509bf4f53eda8d79ed71cce8bd97934", size = 166312, upload-time = "2026-03-23T21:06:55.954Z" } wheels = [ { url = "https://files.pythonhosted.org/packages/4d/51/c936033e16d12b627ea334aaaaf42229c37620d0f15593456ab69ab48161/griffelib-2.0.0-py3-none-any.whl", hash = "sha256:01284878c966508b6d6f1dbff9b6fa607bc062d8261c5c7253cb285b06422a7f", size = 142004, upload-time = "2026-02-09T19:09:40.561Z" }, ] @@ -265,6 +293,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" }, ] +[[package]] +name = "inline-snapshot" +version = "0.32.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "asttokens" }, + { name = "executing" }, + { name = "pytest" }, + { name = "rich" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ca/87/62b78b49042c533038ab1bf0931a7b70fdb78d07a11c9bf159be04027df8/inline_snapshot-0.32.5.tar.gz", hash = "sha256:5025074eab5c82a88504975e2655beeb5e96fd57ed2d9ebb38538473748f2065", size = 2626796, upload-time = "2026-03-13T18:35:54.891Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/09/d3/73426dd3da75095fd071ce5c1f8e520e879a582ca04df861575c6feb9166/inline_snapshot-0.32.5-py3-none-any.whl", hash = "sha256:ac617c273e811ed5ca15abd8f8dbd3fa268296bb0642ccb1403a5df61ce2e39e", size = 84993, upload-time = "2026-03-13T18:35:52.955Z" }, +] + [[package]] name = "logfire-api" version = "4.29.0" @@ -274,6 +319,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e0/cc/62df4abc3e4650c25b81a8e39a1d498d3246c43f3aa4bfab7a73689317b4/logfire_api-4.29.0-py3-none-any.whl", hash = "sha256:48a1361b818357f5a37c71f9683f97e626e5df6c17f35212bfc1f19dddc6771c", size = 121457, upload-time = "2026-03-13T15:30:22.652Z" }, ] +[[package]] +name = "markdown-it-py" +version = "4.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mdurl" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5b/f5/4ec618ed16cc4f8fb3b701563655a69816155e79e24a17b651541804721d/markdown_it_py-4.0.0.tar.gz", hash = "sha256:cb0a2b4aa34f932c007117b194e945bd74e0ec24133ceb5bac59009cda1cb9f3", size = 73070, upload-time = "2025-08-11T12:57:52.854Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/94/54/e7d793b573f298e1c9013b8c4dade17d481164aa517d1d7148619c2cedbf/markdown_it_py-4.0.0-py3-none-any.whl", hash = "sha256:87327c59b172c5011896038353a81343b6754500a08cd7a4973bb48c6d578147", size = 87321, upload-time = "2025-08-11T12:57:51.923Z" }, +] + +[[package]] +name = "mdurl" +version = "0.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d6/54/cfe61301667036ec958cb99bd3efefba235e65cdeb9c84d24a8293ba1d90/mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba", size = 8729, upload-time = "2022-08-14T12:40:10.846Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" }, +] + [[package]] name = "nodeenv" version = "1.10.0" @@ -331,7 +397,7 @@ wheels = [ [[package]] name = "pydantic-ai-slim" -version = "1.70.0" +version = "1.71.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, @@ -343,9 +409,9 @@ dependencies = [ { name = "pydantic-graph" }, { name = "typing-inspection" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ac/97/d57ee44976c349658ea7c645c5c2e1a26830e4b60fdeeee2669d4aaef6eb/pydantic_ai_slim-1.70.0.tar.gz", hash = "sha256:3df0c0e92f72c35e546d24795bce1f4d38f81da2d10addd2e9f255b2d2c83c91", size = 445474, upload-time = "2026-03-18T04:24:34.393Z" } +sdist = { url = "https://files.pythonhosted.org/packages/57/75/09845db7133d4287d16b08a74e7c3db9e8eb8524551de7d3397ed3c60034/pydantic_ai_slim-1.71.0.tar.gz", hash = "sha256:b2d097d4c9f56c820cd2cefbba923dd7d2eb6307eb966cfe1810fb9ba970a43a", size = 493584, upload-time = "2026-03-24T22:00:57.003Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/da/8c/8545d28d0b3a9957aa21393cfdab8280bb854362360b296cd486ed1713ec/pydantic_ai_slim-1.70.0-py3-none-any.whl", hash = "sha256:162907092a562b3160d9ef0418d317ec941c5c0e6dd6e0aa0dbb53b5a5cd3450", size = 576244, upload-time = "2026-03-18T04:24:27.301Z" }, + { url = "https://files.pythonhosted.org/packages/1a/c4/fe5a11e42e7a8e4b92035ac86cb12610fa38ffc7d87350d97f60252073fd/pydantic_ai_slim-1.71.0-py3-none-any.whl", hash = "sha256:02f21e39ca5d809a0a58ad064b4513af6ac8e4d143096ea207281d96e16f2633", size = 634699, upload-time = "2026-03-24T22:00:48.582Z" }, ] [[package]] @@ -468,7 +534,7 @@ wheels = [ [[package]] name = "pydantic-graph" -version = "1.70.0" +version = "1.71.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "httpx" }, @@ -476,22 +542,34 @@ dependencies = [ { name = "pydantic" }, { name = "typing-inspection" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/07/27/f7a71ca2a3705e7c24fd777959cf5515646cc5f23b5b16c886a2ed373340/pydantic_graph-1.70.0.tar.gz", hash = "sha256:3f76d9137369ef8748b0e8a6df1a08262118af20a32bc139d23e5c0509c6b711", size = 58578, upload-time = "2026-03-18T04:24:37.007Z" } +sdist = { url = "https://files.pythonhosted.org/packages/91/85/1c9608b3fa30bca0cd28b513d7dd7e1ac176580d3105d6c415b20b766f11/pydantic_graph-1.71.0.tar.gz", hash = "sha256:493672a71706d2ffa091a8c2ee3361ce575e93eb5f03b0b56bfc4f23fb5715f7", size = 58718, upload-time = "2026-03-24T22:00:59.87Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/38/fd/19c42b60c37dfdbbf5b76c7b218e8309b43dac501f7aaf2025527ca05023/pydantic_graph-1.70.0-py3-none-any.whl", hash = "sha256:6083c1503a2587990ee1b8a15915106e3ddabc8f3f11fbc4a108a7d7496af4a5", size = 72351, upload-time = "2026-03-18T04:24:30.291Z" }, + { url = "https://files.pythonhosted.org/packages/e5/59/54fc89ae1f52f9cf07c82d8cc6e4752cc46abfdeb979b902e07193392879/pydantic_graph-1.71.0-py3-none-any.whl", hash = "sha256:97991c7433f1791062330691adffc0e47f65a4b54bfc124914110bff6ce66d62", size = 72503, upload-time = "2026-03-24T22:00:52.125Z" }, ] [[package]] name = "pydantic-harness" source = { editable = "." } dependencies = [ + { name = "anyio" }, + { name = "pydantic" }, { name = "pydantic-ai-slim" }, ] +[package.optional-dependencies] +monty = [ + { name = "pydantic-monty" }, +] + [package.dev-dependencies] dev = [ + { name = "anyio" }, { name = "coverage" }, + { name = "dirty-equals" }, + { name = "inline-snapshot" }, + { name = "pydantic-monty" }, { name = "pytest" }, + { name = "pytest-mock" }, { name = "pytest-xdist" }, ] lint = [ @@ -500,12 +578,23 @@ lint = [ ] [package.metadata] -requires-dist = [{ name = "pydantic-ai-slim", specifier = ">=0.1" }] +requires-dist = [ + { name = "anyio", specifier = ">=4.5.0" }, + { name = "pydantic", specifier = ">=2.12" }, + { name = "pydantic-ai-slim", specifier = ">=1.71.0" }, + { name = "pydantic-monty", marker = "extra == 'monty'", specifier = ">=0.0.5" }, +] +provides-extras = ["monty"] [package.metadata.requires-dev] dev = [ + { name = "anyio", specifier = ">=4.5.0" }, { name = "coverage" }, - { name = "pytest" }, + { name = "dirty-equals", specifier = ">=0.9.0" }, + { name = "inline-snapshot", specifier = ">=0.19.3" }, + { name = "pydantic-monty", specifier = ">=0.0.5" }, + { name = "pytest", specifier = ">=9.0.0" }, + { name = "pytest-mock", specifier = ">=3.14.0" }, { name = "pytest-xdist" }, ] lint = [ @@ -513,6 +602,74 @@ lint = [ { name = "ruff", specifier = ">=0.14" }, ] +[[package]] +name = "pydantic-monty" +version = "0.0.8" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/47/bf/e9794b562c207406d8fda0cf4fea810943a5e8a85fe69e5505046179df16/pydantic_monty-0.0.8.tar.gz", hash = "sha256:8135e781a184f971825c1d2eb6d621598103e900f6e0d34291ff0bf35df6142f", size = 802646, upload-time = "2026-03-10T14:46:51.353Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/32/5a/1d80f5b27717cb5842d9140b8f4033c61340a0f24588cdd06e8c1d71ec57/pydantic_monty-0.0.8-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:1a15033d8479adf4566fa52f9067640a83e3478340757f480c9924032386a9a3", size = 6700740, upload-time = "2026-03-10T14:46:17.612Z" }, + { url = "https://files.pythonhosted.org/packages/92/62/171af2737950fabbf25cdcf7dd8b05dbe1bc6fc30fb4d40e1b4da013cc06/pydantic_monty-0.0.8-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6785dd86af61e0898b6dd6162db57b1fa5d971572ec30b0b60eb79f9034f0e1d", size = 6768818, upload-time = "2026-03-10T14:46:27.624Z" }, + { url = "https://files.pythonhosted.org/packages/8d/7c/5ee596f92977dae7f9d677432da7883483be512efa0432b416d67f42155e/pydantic_monty-0.0.8-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e036bcf26fa7ac0bc8d370b7786faef9e89e4dd9e690d2e02abb96cb534f1382", size = 6504276, upload-time = "2026-03-10T14:45:45.166Z" }, + { url = "https://files.pythonhosted.org/packages/d3/a8/2edc58fe3453941f75a2933e7bffa08f30a448d3e28f81679950fafec681/pydantic_monty-0.0.8-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1a7341af27a08112e78ad7879549b1c2684b6405d14891bf972a0c51b845b2a4", size = 6764775, upload-time = "2026-03-10T14:45:09.916Z" }, + { url = "https://files.pythonhosted.org/packages/b5/80/55bf6c7b76786142c14cd6dc8aa350e8cfd0e04ee256bd7b267caf105060/pydantic_monty-0.0.8-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e2e1a3b4150756e6e16ea6f57c2da18839d4754b1ef564acf8750156783b8631", size = 7324970, upload-time = "2026-03-10T14:45:32.139Z" }, + { url = "https://files.pythonhosted.org/packages/21/dc/1f578581e1a09dff04df303327aac2a23160da7906d9b7898ed462980f43/pydantic_monty-0.0.8-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0133ee9335f18e049bc0c9d40e81b995b30c6ac2da0dde8549d883599f0b3791", size = 7530638, upload-time = "2026-03-10T14:46:06.424Z" }, + { url = "https://files.pythonhosted.org/packages/39/21/c783a54695d76673e1fe7285edd99b241f0743e00dd99f3645d7ddea6585/pydantic_monty-0.0.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1d515ff3a284757016ee741bb337072a228c9ac8ebe63ff31263f337651eaf3c", size = 7298036, upload-time = "2026-03-10T14:46:46.346Z" }, + { url = "https://files.pythonhosted.org/packages/ca/38/36838e08baac3a675c382f4f27d651e4aa79a77d08b5f5efb45bc95a43a2/pydantic_monty-0.0.8-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:687bce9965b2be994cc72c6e3d6f9debf578beba7b33ce26160db209d41aed3f", size = 7186272, upload-time = "2026-03-10T14:46:26.088Z" }, + { url = "https://files.pythonhosted.org/packages/23/d2/eaad2d0c35451ce1697b8afc3d3bcab263a5bbaf71d1bee91a35a48a9a5d/pydantic_monty-0.0.8-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:aae593b5fd001f7ba026b6fe780028a0fc951c54f39e3dd1c232c225243be17d", size = 6678215, upload-time = "2026-03-10T14:46:44.914Z" }, + { url = "https://files.pythonhosted.org/packages/60/f3/031c7277082f60e43dd14490f9e50262ba6298df2793872df5f55da6ec97/pydantic_monty-0.0.8-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:199c8135a1ac4ce4eee477d529aaff1986558542109cf207bf2071cf400006da", size = 7136708, upload-time = "2026-03-10T14:45:41.191Z" }, + { url = "https://files.pythonhosted.org/packages/e9/d7/b7d94ab8e74f2431ed315c554d00f919674c650a0d03c2f22f91f3f62d53/pydantic_monty-0.0.8-cp310-cp310-win32.whl", hash = "sha256:643b6285fc921b5a73e23dcd71315a38eebf33bd75ba10deb3736c8bf8171d2d", size = 6585407, upload-time = "2026-03-10T14:46:11.335Z" }, + { url = "https://files.pythonhosted.org/packages/21/59/53af269e465828a48b544a5d42bb3278a4ba1bf657169358ba4575713d54/pydantic_monty-0.0.8-cp310-cp310-win_amd64.whl", hash = "sha256:85a0378ae97dbc56aecfbe9e097c5bb8b3d402f44ecf4de63016711751f9b9bb", size = 7363913, upload-time = "2026-03-10T14:46:37.859Z" }, + { url = "https://files.pythonhosted.org/packages/31/32/ac657c9e517665cc74bf47c326a2b9585b1ed0a9ebba5d624c8ad3cf3892/pydantic_monty-0.0.8-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:b36cd5b8e380d9624a6929203a35e8e0d6e11b2bdb840553f0079000c1acf84e", size = 6700260, upload-time = "2026-03-10T14:45:19.034Z" }, + { url = "https://files.pythonhosted.org/packages/a7/0b/042f23c8211cc74b508eb7a648f7dd3ce6b39ef6fb72ebd8f1f6c060f29b/pydantic_monty-0.0.8-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:82adee95051eab2b04d11e4a11e852eafa2d00d8adacf62176ec25bd0d3f6a7b", size = 6767968, upload-time = "2026-03-10T14:46:14.422Z" }, + { url = "https://files.pythonhosted.org/packages/77/9a/c0d387b8b3ad6255645044afc8ac3f1cece7f8d91e819aaefeb738251634/pydantic_monty-0.0.8-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9f7840f0e0d3933dc3c2f64402275d9176877ecbe565008dee657305f5c6e40f", size = 6502971, upload-time = "2026-03-10T14:46:22.841Z" }, + { url = "https://files.pythonhosted.org/packages/d8/58/ebda094d5ac8fbdd74dfc2041d7a04289fd13a9ad19accdf22841ef00d79/pydantic_monty-0.0.8-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b5a904d8d489bb6494458364874af31a8a933f7b316234fd69b25edf6b0567b9", size = 6764397, upload-time = "2026-03-10T14:46:12.921Z" }, + { url = "https://files.pythonhosted.org/packages/52/58/9a89b514b2953f85cc089149bcbc1ed4a79d9ee90cd6448ce85d71d04ece/pydantic_monty-0.0.8-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c44c7464c9286a8c39ee58bf45d0a477e2f281c49ce99796210874f58bbe52f7", size = 7325434, upload-time = "2026-03-10T14:46:01.563Z" }, + { url = "https://files.pythonhosted.org/packages/b1/fe/cf2c8556589d11f85ab82ff1cc2ead3f1e70a1bdbc03e6b0648bf66186c7/pydantic_monty-0.0.8-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:94644e34771e4255772d41e31c30c9ab1f9df06860507cb5f47152760a15767d", size = 7530790, upload-time = "2026-03-10T14:45:11.566Z" }, + { url = "https://files.pythonhosted.org/packages/f8/56/b981e8e29e316d3f327142fc206c4d727ecaa1b875057110b3e754c1fc52/pydantic_monty-0.0.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:70fa4aba236678a1d0b04e07300d404f2504351922b6aecec7aa4b4d84785d3e", size = 7297432, upload-time = "2026-03-10T14:46:36.305Z" }, + { url = "https://files.pythonhosted.org/packages/f3/4e/8b41a6e1461e52478d8ca3ef19db6e271c4b7f27e0073337b3f98e35d3a4/pydantic_monty-0.0.8-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:8df1d1fa0deed81de60e43e3ab8c4a82f63998b6e04e535e681633b4eaff98b6", size = 7185726, upload-time = "2026-03-10T14:45:05.965Z" }, + { url = "https://files.pythonhosted.org/packages/eb/be/28f86c4c7679359c9a2fcd17392d7716550e87b8a4d381cf865bc1014a31/pydantic_monty-0.0.8-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:f0344a975c59f5f35ca6c0c2b332d7d8d7c681449c5ce8849181ed05025b18c4", size = 6677665, upload-time = "2026-03-10T14:46:31.198Z" }, + { url = "https://files.pythonhosted.org/packages/31/68/82374a2e8a832f72c65a4a26d20fdb756c089cbce0984d7153f70fbb05fb/pydantic_monty-0.0.8-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:8aad6b5eee8d8c1dda9b63b6ea3b7c524b64b8c06ad49ba9cc015740c4f86c78", size = 7136490, upload-time = "2026-03-10T14:45:17.199Z" }, + { url = "https://files.pythonhosted.org/packages/63/46/4b487728ed5a97341a21fd54a32ec11552757e843bb69983f72eae0caa41/pydantic_monty-0.0.8-cp311-cp311-win32.whl", hash = "sha256:d7e16a25b0066e34a2694dac157fc93971f80c67c4f1b47ad6e5049423141e6a", size = 6585220, upload-time = "2026-03-10T14:45:47.124Z" }, + { url = "https://files.pythonhosted.org/packages/87/d6/b2665d7590849969da9dbde08fdc1810f8ded7891d8fff3488ceb557eda2/pydantic_monty-0.0.8-cp311-cp311-win_amd64.whl", hash = "sha256:85cdec430afb4bbe1860aadf910c1b00832037ce08a2fb64e6012927a58ea14f", size = 7364044, upload-time = "2026-03-10T14:45:59.698Z" }, + { url = "https://files.pythonhosted.org/packages/f4/4f/7d7c7531be850469bccfbbf3cfede9e95d92b1d8e7b245d9dc77a599d6e4/pydantic_monty-0.0.8-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:025b380c0ed728bdf88e3ed60d5977498d4bb9da61cd81d9cabdaecce16f5755", size = 6699454, upload-time = "2026-03-10T14:46:39.422Z" }, + { url = "https://files.pythonhosted.org/packages/c0/35/074daa5fd92e4a5c1e49c8fae06036e194139feef9a51b4db82d0fee7e54/pydantic_monty-0.0.8-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:76846d77ebda3414beb8c9e5da9eaef722e06606dc0307613186d888e775dc51", size = 6743273, upload-time = "2026-03-10T14:46:16.184Z" }, + { url = "https://files.pythonhosted.org/packages/55/0d/e6cb1e9e1c2e51501d8f7848c18803780164948e338350ab94769690f207/pydantic_monty-0.0.8-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3b58149d6a8998ffed80f626c10bf0de7e65964aae1dfad80db8ecb00968fb1a", size = 6503552, upload-time = "2026-03-10T14:45:26.083Z" }, + { url = "https://files.pythonhosted.org/packages/5f/a9/286f7eb6c95d7877f6f8b9bcbe26e2d988fa142cd77509e48e67302478bf/pydantic_monty-0.0.8-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1b99bb030ff8e95b160c11702b9a6823705b7dcc1a49a2c1dccc43c2538bfe27", size = 6765377, upload-time = "2026-03-10T14:45:14.472Z" }, + { url = "https://files.pythonhosted.org/packages/74/76/85268f6305bc5b153be7c0860e69ce3b9ba916daa4a419f53d8a777e9a39/pydantic_monty-0.0.8-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b05b17fc2b8875e0efea047416fc0cd28a8018c040b25387454996b98540ccfc", size = 7324571, upload-time = "2026-03-10T14:45:48.786Z" }, + { url = "https://files.pythonhosted.org/packages/11/6a/2855c6149f6ba3138c7bfb009c07d528e2df12802e3216928b1289bbf233/pydantic_monty-0.0.8-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:97fa5deef4d8cdcf60a4d8a4e7a34099deae1ad4b33acc1c33ac2e1cc348a1c2", size = 7533828, upload-time = "2026-03-10T14:46:55.386Z" }, + { url = "https://files.pythonhosted.org/packages/31/63/44b5bb5798323f7f735f5855a0d3f478d6852d9d1774427d77e31b5dbffb/pydantic_monty-0.0.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f1aabd56af5db51c4af512142fe7ffc866b6a20dcd039a4aad5245ef42768f21", size = 7271978, upload-time = "2026-03-10T14:45:33.903Z" }, + { url = "https://files.pythonhosted.org/packages/4b/68/ce82eb571e45afbe3c0b9544fe3ebf93f841ec895fea0d39c9604a0a421f/pydantic_monty-0.0.8-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9782e58da29176a3fb0ce8ec808571d0be99c1e2ad167637225246e594e85f13", size = 7187140, upload-time = "2026-03-10T14:46:41.784Z" }, + { url = "https://files.pythonhosted.org/packages/5b/45/778bc260195ddb892f284c3cb8ab8cbcb0542e6a18f11b7e50a592b507ba/pydantic_monty-0.0.8-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ff7fb7ff3f4e830ef9b28a76e634219195b552c0dfed3471fb4910708db56221", size = 6677996, upload-time = "2026-03-10T14:46:03.163Z" }, + { url = "https://files.pythonhosted.org/packages/1e/02/8e8396b83d19ec70a09c24b0245177a595c2b7d6d092c6f6c7d3310b191c/pydantic_monty-0.0.8-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ce7222d923676900827e3951e65f854d65f50b65ed8c7a84011d3472773d51bf", size = 7136513, upload-time = "2026-03-10T14:45:52.179Z" }, + { url = "https://files.pythonhosted.org/packages/fd/24/53d06af74b82043be9c4960c38bbc255eeb3de9f2a4d450f942912630bb3/pydantic_monty-0.0.8-cp312-cp312-win32.whl", hash = "sha256:4cb59e1e7b1a3d573247a871a87c88506ce9fb68bbd7648598e171cf2f04747e", size = 6582168, upload-time = "2026-03-10T14:45:58.197Z" }, + { url = "https://files.pythonhosted.org/packages/25/0f/55a16faf379139263ad852421734cb096390777d299ef7f185ec10404656/pydantic_monty-0.0.8-cp312-cp312-win_amd64.whl", hash = "sha256:45bf11e3b795cc470a91cbd7cfeb9d96f7a60387e8da146a11bf952b4371aba7", size = 7335422, upload-time = "2026-03-10T14:45:28.301Z" }, + { url = "https://files.pythonhosted.org/packages/27/46/3268001a639052515d5a55ea2f1e087ae1e5f7aa9d7bc62c4808d731fff1/pydantic_monty-0.0.8-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:f303f1d979213dcb365c69de3912daa70a32af019bb5ee86e086f584638974f2", size = 6698529, upload-time = "2026-03-10T14:45:39.712Z" }, + { url = "https://files.pythonhosted.org/packages/f4/22/28d6b7f8f7a1e0881c449310a6f2c8ee0cf85dc4434ae0f7e633dfcf5bcf/pydantic_monty-0.0.8-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:b8cc698da6f7f11743df7c48b959eea000e8e30511c7e1318b67543e25985062", size = 6743849, upload-time = "2026-03-10T14:46:48.023Z" }, + { url = "https://files.pythonhosted.org/packages/7b/ab/e3dfe057af472e4065297d69aea0cf30616d01366a29e02d0a2739f22b93/pydantic_monty-0.0.8-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:90f6da093164e99012b49b39ba1e64c166ccd126128d114f9e2c66df6fb695c4", size = 6503266, upload-time = "2026-03-10T14:46:19.594Z" }, + { url = "https://files.pythonhosted.org/packages/0a/82/d3e9aa9bd9ac69b3584216166a189e11370913ffcbd2a57a5cac6ce2d4ba/pydantic_monty-0.0.8-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f2432ef62140554f5970991b7df41e064824741a807515658f91283c75669086", size = 6765032, upload-time = "2026-03-10T14:45:20.571Z" }, + { url = "https://files.pythonhosted.org/packages/f8/9d/7ab11be8eff998bf9283c7bf444254cff5212f87fe3f2a9a2f7436cce6d7/pydantic_monty-0.0.8-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:99951afb4d722212c2ce85c303b5e87d50756adaff46817ad0e60e95eb1e5141", size = 7324673, upload-time = "2026-03-10T14:45:37.957Z" }, + { url = "https://files.pythonhosted.org/packages/39/a0/7f026ece228cc990e58aa27f2ce5af26d042baaa0186cb451f3e04ff0abe/pydantic_monty-0.0.8-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bdef70921b408378f9bc38bce4952d3c176b3f7833aa782068b310390901c516", size = 7533774, upload-time = "2026-03-10T14:46:24.265Z" }, + { url = "https://files.pythonhosted.org/packages/07/e7/72f250ffd005520ad8cdffb387241a9dd92fc70bcbdd769720918bd34495/pydantic_monty-0.0.8-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:00b5dee7bef619e7661f77bef765f483605bcd2a79c6b8bf5c910afa1c93fc40", size = 7272179, upload-time = "2026-03-10T14:45:08.085Z" }, + { url = "https://files.pythonhosted.org/packages/b8/7c/97c87c2a315ba4376bfb83197586365003005e01d6547df7f6d77b80d13f/pydantic_monty-0.0.8-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:5f4c6f6fb2ebae0bc80520af71e1057055413ca96fd33f13f3c612a085e50fb8", size = 7186684, upload-time = "2026-03-10T14:46:49.833Z" }, + { url = "https://files.pythonhosted.org/packages/09/42/2eea55906fee8bef7c1024bf2d35e2c301b15b562934731bdae25db448cb/pydantic_monty-0.0.8-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:3ab3c558a648942d4d31e7f668d60d2a2e129751a3e8e9dc27b1e6d635e9e627", size = 6677272, upload-time = "2026-03-10T14:45:54.037Z" }, + { url = "https://files.pythonhosted.org/packages/40/0a/066e53d4693b680e39080d3af4f234c1ff9976f7621b8a49dd2a41710431/pydantic_monty-0.0.8-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:782d686d217b537e6fa9047aa3444d67b3b7cc69bf6faf34078c992ceaeb1e9e", size = 7136493, upload-time = "2026-03-10T14:45:30.014Z" }, + { url = "https://files.pythonhosted.org/packages/81/c1/46dc300f87314aef883a57ae5a35ba45a463c09f64d3d9c9f0b620672734/pydantic_monty-0.0.8-cp313-cp313-win32.whl", hash = "sha256:ef0db454757cb92974890b11c7b0969d8c0d95944c821f6fc6cd23846d1d2aa4", size = 6581014, upload-time = "2026-03-10T14:46:07.886Z" }, + { url = "https://files.pythonhosted.org/packages/b0/2a/cf156df19a1612ca5ddcbd982645e54c5485e83215d9532cc9aff791f854/pydantic_monty-0.0.8-cp313-cp313-win_amd64.whl", hash = "sha256:dbf8c7cfaff2b345f8c1bfba98fdc282e790fef5c601f37c8d5355ccd45073de", size = 7335041, upload-time = "2026-03-10T14:46:21.371Z" }, + { url = "https://files.pythonhosted.org/packages/a5/69/5bedc7ad67fdd9e4f04007477f5c414b54c51d47c5dfdd7abad4c78663aa/pydantic_monty-0.0.8-cp314-cp314-macosx_10_12_x86_64.whl", hash = "sha256:47990770ed74af1e8e2160a7dd925aa1e4fd1bf4ae9658ffd1cdc32eb5c8c6e8", size = 6700454, upload-time = "2026-03-10T14:45:56.247Z" }, + { url = "https://files.pythonhosted.org/packages/bf/35/06aaa9c766a83e7e95219bfb6cb88bd0a87803f85eb3f596b82da0cb3009/pydantic_monty-0.0.8-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:9534f6236b3ebdd09f0f7d71e508e68736733d24cdda2a6c0a758914261c75e5", size = 6760180, upload-time = "2026-03-10T14:45:22.982Z" }, + { url = "https://files.pythonhosted.org/packages/f3/b8/eaa4a4f0b3a1c343773317027bb5d1e11a1b0c1ad01d3f0921a7f547d346/pydantic_monty-0.0.8-cp314-cp314-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76c4ffd8275d520628732d82e7410e5f1f69ef5af274d676f1f2f9221fd85a34", size = 6504414, upload-time = "2026-03-10T14:45:12.984Z" }, + { url = "https://files.pythonhosted.org/packages/9d/79/dbb0875ad2b565d7ef1981feda4438ab743130f4031107aaadc9212488dd/pydantic_monty-0.0.8-cp314-cp314-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:badaa22749fa7a22ee2cb88f6deb48fc191cb2f41237dd630593259ed9f30b67", size = 6766768, upload-time = "2026-03-10T14:46:29.12Z" }, + { url = "https://files.pythonhosted.org/packages/25/cb/bd6bef8fa2cfe807dd5ef36ab8ce6d094ecdad0050cf57dec4c8a438d413/pydantic_monty-0.0.8-cp314-cp314-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e39104684eda1ae3c290e3a4369bef9d00f9b637be8515802bc1523fecb5160d", size = 7326073, upload-time = "2026-03-10T14:45:50.679Z" }, + { url = "https://files.pythonhosted.org/packages/aa/44/b64b6e2857519d5fdc66f74998019e585e5c3bf2cf8deb7b77419d29db2a/pydantic_monty-0.0.8-cp314-cp314-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8c2e5b3b11b794f82504bb6929f4101f0db84933986aee9b46ce7561add152e9", size = 7535251, upload-time = "2026-03-10T14:46:04.645Z" }, + { url = "https://files.pythonhosted.org/packages/da/11/4e1524d94e33427990cffd74eaf94a023724c872f11d42ac90eebd74ccc0/pydantic_monty-0.0.8-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a94fc87b2a9b1b66a09f35b819b1c7e7da7702b21c923cdca21fa13129486b45", size = 7288739, upload-time = "2026-03-10T14:45:35.609Z" }, + { url = "https://files.pythonhosted.org/packages/2c/7d/7e667c72c9742a6725b04d21faf50e8f552a0e23da9ca1cbeba891961c9a/pydantic_monty-0.0.8-cp314-cp314-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:70b7479e7bd47d274fbd6278a625289e71b1a4aad400f40c7a4a89fe9a50952c", size = 7189195, upload-time = "2026-03-10T14:46:09.399Z" }, + { url = "https://files.pythonhosted.org/packages/7b/3f/3243277e2c6bf4c3e4684fbc78c316a8965dbbb012b573b38978d8628bfc/pydantic_monty-0.0.8-cp314-cp314-musllinux_1_1_aarch64.whl", hash = "sha256:48b8c9cb9a96a5f28b03669771368a444c8b50c758b6a5bd0d148fa02db3f65b", size = 6679333, upload-time = "2026-03-10T14:45:24.494Z" }, + { url = "https://files.pythonhosted.org/packages/0c/51/88152b89f5288e9d3cd28e71f79a6567421e20d8f00b0514b7819a29a8bb/pydantic_monty-0.0.8-cp314-cp314-musllinux_1_1_x86_64.whl", hash = "sha256:b005f10798a1b58a12b5643f69a7b54208bd4965d16ca2bd21d13fca4d1a4f67", size = 7137771, upload-time = "2026-03-10T14:46:33.043Z" }, + { url = "https://files.pythonhosted.org/packages/d8/38/baf5a66ed72b931de7ce92251822d6ffb752582389ff4f4c853742cee029/pydantic_monty-0.0.8-cp314-cp314-win32.whl", hash = "sha256:b143bba29c274e15424a097af6ae85a1e81e6c2fb7184ee7a93fbf10997dbbd3", size = 6584038, upload-time = "2026-03-10T14:46:34.53Z" }, + { url = "https://files.pythonhosted.org/packages/71/80/bfda690914fa15c68f8be9e824e62c4054118a9a72304735221446a89014/pydantic_monty-0.0.8-cp314-cp314-win_amd64.whl", hash = "sha256:f4bc9185bb5f37f3220a978889b4bc6f822a41ee4d5523c17eef4d869aca66e8", size = 7351170, upload-time = "2026-03-10T14:46:57.025Z" }, +] + [[package]] name = "pygments" version = "2.19.2" @@ -553,6 +710,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" }, ] +[[package]] +name = "pytest-mock" +version = "3.15.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/68/14/eb014d26be205d38ad5ad20d9a80f7d201472e08167f0bb4361e251084a9/pytest_mock-3.15.1.tar.gz", hash = "sha256:1849a238f6f396da19762269de72cb1814ab44416fa73a8686deac10b0d87a0f", size = 34036, upload-time = "2025-09-16T16:37:27.081Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5a/cc/06253936f4a7fa2e0f48dfe6d851d9c56df896a9ab09ac019d70b760619c/pytest_mock-3.15.1-py3-none-any.whl", hash = "sha256:0a25e2eb88fe5168d535041d09a4529a188176ae608a6d249ee65abc0949630d", size = 10095, upload-time = "2025-09-16T16:37:25.734Z" }, +] + [[package]] name = "pytest-xdist" version = "3.8.0" @@ -566,6 +735,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ca/31/d4e37e9e550c2b92a9cbc2e4d0b7420a27224968580b5a447f420847c975/pytest_xdist-3.8.0-py3-none-any.whl", hash = "sha256:202ca578cfeb7370784a8c33d6d05bc6e13b4f25b5053c30a152269fd10f0b88", size = 46396, upload-time = "2025-07-01T13:30:56.632Z" }, ] +[[package]] +name = "rich" +version = "14.3.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markdown-it-py" }, + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b3/c6/f3b320c27991c46f43ee9d856302c70dc2d0fb2dba4842ff739d5f46b393/rich-14.3.3.tar.gz", hash = "sha256:b8daa0b9e4eef54dd8cf7c86c03713f53241884e814f4e2f5fb342fe520f639b", size = 230582, upload-time = "2026-02-19T17:23:12.474Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/14/25/b208c5683343959b670dc001595f2f3737e051da617f66c31f7c4fa93abc/rich-14.3.3-py3-none-any.whl", hash = "sha256:793431c1f8619afa7d3b52b2cdec859562b950ea0d4b6b505397612db8d5362d", size = 310458, upload-time = "2026-02-19T17:23:13.732Z" }, +] + [[package]] name = "ruff" version = "0.15.7"