diff --git a/.gitignore b/.gitignore index a54049e..bc634fe 100644 --- a/.gitignore +++ b/.gitignore @@ -25,3 +25,4 @@ wheels/ # Hypothesis .hypothesis/ .vscode/ +mutants/ diff --git a/README.md b/README.md index 37fee2c..4d4e151 100644 --- a/README.md +++ b/README.md @@ -103,8 +103,8 @@ We studied leading coding agents, agent frameworks, and Claw-style assistants to |---|---|---|---|---| | **Tools & execution** | **Code mode** | Sandboxed Python execution via [Monty](https://github.com/pydantic/monty) -- one `run_code` call replaces N tool calls | :white_check_mark: [Docs](pydantic_ai_harness/code_mode/) | | | | **Tool search** | Progressive tool discovery for large tool sets | :white_check_mark: [Pydantic AI](https://pydantic.dev/docs/ai/tools-toolsets/toolsets/#deferred-loading) | | -| | **File system** | Read, write, edit, search files with path traversal prevention | :construction: [PR #177](https://github.com/pydantic/pydantic-ai-harness/pull/177) | [pydantic-ai-backend](https://github.com/vstorm-co/pydantic-ai-backend) (vstorm‑co) | -| | **Shell** | Execute commands with allowlists, denylists, and timeouts | :construction: [PR #177](https://github.com/pydantic/pydantic-ai-harness/pull/177) | [pydantic-ai-backend](https://github.com/vstorm-co/pydantic-ai-backend) (vstorm‑co) | +| | **File system** | Read, write, edit, search files with path traversal prevention | :white_check_mark: [Docs](pydantic_ai_harness/filesystem/) | [pydantic-ai-backend](https://github.com/vstorm-co/pydantic-ai-backend) (vstorm‑co) | +| | **Shell** | Execute commands with allowlists, denylists, and timeouts | :white_check_mark: [Docs](pydantic_ai_harness/shell/) | [pydantic-ai-backend](https://github.com/vstorm-co/pydantic-ai-backend) (vstorm‑co) | | | **Repo context injection** | Auto-load CLAUDE.md/AGENTS.md and repo structure | :construction: [PR #175](https://github.com/pydantic/pydantic-ai-harness/pull/175) | [pydantic-deep](https://github.com/vstorm-co/pydantic-deepagents) (vstorm‑co) | | | **Verification loop** | Run tests after edits, auto-fix failures | :construction: [PR #169](https://github.com/pydantic/pydantic-ai-harness/pull/169) | | | **Context management** | **Sliding window** | Trim conversation history to stay within token limits | :construction: [PR #191](https://github.com/pydantic/pydantic-ai-harness/pull/191) | [summarization-pydantic-ai](https://github.com/vstorm-co/summarization-pydantic-ai) (vstorm‑co) | diff --git a/docs/mutation-testing.md b/docs/mutation-testing.md new file mode 100644 index 0000000..f7112c8 --- /dev/null +++ b/docs/mutation-testing.md @@ -0,0 +1,47 @@ +# Mutation Testing + +Mutation testing complements the 100% branch-coverage requirement: coverage +proves every line and branch runs, mutation testing proves the assertions +actually pin the behavior down. + +Covers `pydantic_ai_harness/filesystem/_toolset.py` and +`pydantic_ai_harness/shell/_toolset.py`. + +Run with [mutmut](https://mutmut.readthedocs.io/) v3 via `scripts/run-mutmut.sh`, +which installs mutmut ephemerally with `uv run --with` — no dev dependency +required. + +```bash +scripts/run-mutmut.sh run --max-children 1 +scripts/run-mutmut.sh results +scripts/run-mutmut.sh show +``` + +## Interpreting survivors + +A surviving mutant is either a missing test or an equivalent mutant — a change +that produces behavior no test could distinguish from the original. Triage each +survivor; the recurring equivalent-mutant categories in this codebase are: + +- **Trampoline default params** — mutmut v3 wraps functions, and the wrapper + keeps the original defaults, so a mutated default is never observed. +- **Omitted `name=` in `add_function()`** — pydantic-ai falls back to + `method.__name__`, which equals the explicit name being mutated away. +- **`'utf-8'` encoding mutations** — Python's codec lookup is case-insensitive + and UTF-8 is the default text encoding, so case/omission changes are no-ops. +- **`errors='replace'` mutations** — exercised only by invalid bytes; valid + UTF-8 test data never invokes the error handler. +- **Unreachable `except` blocks** (marked `pragma: no cover`) — paths that + can't be triggered in the test environment. +- **`CancelScope(shield=True)` flips** — require an outer cancellation during + the near-instant cleanup window. + +Anything outside these categories should be treated as a real gap and killed +with a new test. + +## Limitations + +Trio-parametrized tests are excluded during mutation testing (`-k 'not trio'` +in `pyproject.toml [tool.mutmut]`) because trio segfaults in mutmut's +subprocess environment on Python 3.14 / macOS. The kill rate is unaffected — +the trio tests exercise the same code paths as the asyncio tests. diff --git a/pydantic_ai_harness/__init__.py b/pydantic_ai_harness/__init__.py index 0a60fd7..4f6f62d 100644 --- a/pydantic_ai_harness/__init__.py +++ b/pydantic_ai_harness/__init__.py @@ -1,11 +1,13 @@ -"""The batteries for your Pydantic AI agent -- the official capability library.""" +"""Pydantic AI capability library.""" from typing import TYPE_CHECKING if TYPE_CHECKING: from .code_mode import CodeMode + from .filesystem import FileSystem + from .shell import Shell -__all__ = ['CodeMode'] +__all__ = ['CodeMode', 'FileSystem', 'Shell'] def __getattr__(name: str) -> object: @@ -13,4 +15,12 @@ def __getattr__(name: str) -> object: from .code_mode import CodeMode return CodeMode + elif name == 'FileSystem': + from .filesystem import FileSystem + + return FileSystem + elif name == 'Shell': + from .shell import Shell + + return Shell raise AttributeError(f'module {__name__!r} has no attribute {name!r}') diff --git a/pydantic_ai_harness/filesystem/README.md b/pydantic_ai_harness/filesystem/README.md new file mode 100644 index 0000000..5a8b0a1 --- /dev/null +++ b/pydantic_ai_harness/filesystem/README.md @@ -0,0 +1,136 @@ +# FileSystem + +Give an agent sandboxed, pattern-filtered access to a directory tree. + +## The problem + +Letting an agent touch the filesystem directly is risky: path traversal +(`../../etc/passwd`), symlinks that escape the project, clobbering `.git`, or +leaking `.env` secrets. Hand-rolling the guards around every tool call is +repetitive and easy to get subtly wrong. + +## The solution + +`FileSystem` exposes a fixed set of file tools, all scoped to a single +`root_dir`. Every path is resolved and containment-checked (symlinks included) +before any I/O, and access is filtered through allow / deny / protected glob +patterns. + +```python +from pydantic_ai import Agent +from pydantic_ai_harness import FileSystem + +agent = Agent( + 'anthropic:claude-sonnet-4-6', + capabilities=[FileSystem(root_dir='./workspace')], +) + +result = agent.run_sync('Read config.toml and tell me the package name.') +print(result.output) +``` + +## Tools + +| Tool | Purpose | +|---|---| +| `read_file` | Read a text file with line numbers and a content hash. Binary files are detected and not dumped. | +| `write_file` | Create or overwrite a file. Optional `expected_hash` rejects stale writes (optimistic concurrency). | +| `edit_file` | Exact-string replacement; `old_text` must match exactly once. Optional `expected_hash`. | +| `list_directory` | List a directory's entries with type indicators and sizes. | +| `search_files` | Regex search over file contents, optionally narrowed by an `include_glob`. | +| `find_files` | Glob search over file names (e.g. `*.py`, `**/*.json`). | +| `create_directory` | Create a directory and any missing parents. | +| `file_info` | Metadata for a file or directory (size, type, line count, hash, symlink target). | + +## Security model + +- **Containment.** Paths resolve relative to `root_dir`; anything resolving + outside — via `..`, an absolute path, or a symlink — is rejected. Symlinks + are resolved with `os.path.realpath` *before* the containment check, closing + the TOCTTOU window. +- **Binary detection.** `read_file` returns a placeholder instead of dumping + binary bytes into the model context. +- **Optimistic concurrency.** `write_file`/`edit_file` accept an + `expected_hash` so an agent operating on a stale read is told to re-read + rather than silently overwriting newer content. + +## Pattern filtering + +Three independent glob lists control access. Patterns are matched with +`fnmatch`, whose `*` spans `/`, so `*.py` matches `src/main.py` and you rarely +need `**`. + +| Field | Effect | +|---|---| +| `allowed_patterns` | If non-empty, only matching paths are accessible (allowlist). | +| `denied_patterns` | Matching paths are always rejected (denylist). | +| `protected_patterns` | Matching paths are read-only — reads succeed, writes are rejected. | + +`protected_patterns` defaults to `.git/`, `.env`/`.env.*`, `*.pem`, `*.key`, +and `**/secrets*`. Pass an empty list to disable protection. + +### Direct access vs. walkers + +The three rules apply at two different granularities: + +- **Direct access** (`read_file`, `write_file`, `edit_file`, `file_info`, + `create_directory`) gates the operation's target path. You must name a path + that the patterns permit. +- **Walkers** (`list_directory`, `search_files`, `find_files`) gate their root + by deny/protected patterns, but **not** by `allowed_patterns` — a directory + root like `.` never matches a file pattern such as `src/*.py`, so requiring + it to would make every listing fail. Instead, the root is always walked and + each **entry** is filtered against all three lists. A directory listing can + never surface a path the agent couldn't otherwise read or write. + +So with `allowed_patterns=['*.py']`, `list_directory('.')` succeeds and shows +only the `.py` entries; `read_file('notes.md')` is rejected. + +> Dotfiles and dot-directories (`.git`, `.env`, `.github`, …) are skipped by +> all three walkers — `list_directory`, `search_files`, and `find_files` — +> regardless of patterns. + +## Configuration + +```python +FileSystem( + root_dir='.', # str | Path — sandbox root + allowed_patterns=[], # allowlist globs (empty = allow all) + denied_patterns=[], # denylist globs + protected_patterns=[...], # read-only globs (defaults to secrets/.git) + max_read_lines=2000, # cap for a single read_file + max_search_results=1000, # cap for search_files + max_find_results=1000, # cap for find_files +) +``` + +The integer limits must be positive; they are validated at construction. + +## Agent spec (YAML/JSON) + +`FileSystem` works with Pydantic AI's +[agent spec](https://ai.pydantic.dev/agent-spec/): + +```yaml +# agent.yaml +model: anthropic:claude-sonnet-4-6 +capabilities: + - FileSystem: + root_dir: ./workspace + allowed_patterns: ['*.py', '*.toml'] +``` + +```python +from pydantic_ai import Agent +from pydantic_ai_harness import FileSystem + +agent = Agent.from_file('agent.yaml', custom_capability_types=[FileSystem]) +``` + +Pass `custom_capability_types` so the spec loader knows how to instantiate +`FileSystem`. + +## Further reading + +- [Pydantic AI capabilities](https://ai.pydantic.dev/capabilities/) +- [Toolsets](https://ai.pydantic.dev/toolsets/) diff --git a/pydantic_ai_harness/filesystem/__init__.py b/pydantic_ai_harness/filesystem/__init__.py new file mode 100644 index 0000000..5c73527 --- /dev/null +++ b/pydantic_ai_harness/filesystem/__init__.py @@ -0,0 +1,6 @@ +"""Filesystem capability: gives agents configurable, sandboxed file system access.""" + +from pydantic_ai_harness.filesystem._capability import FileSystem +from pydantic_ai_harness.filesystem._toolset import FileSystemToolset + +__all__ = ['FileSystem', 'FileSystemToolset'] diff --git a/pydantic_ai_harness/filesystem/_capability.py b/pydantic_ai_harness/filesystem/_capability.py new file mode 100644 index 0000000..218b875 --- /dev/null +++ b/pydantic_ai_harness/filesystem/_capability.py @@ -0,0 +1,81 @@ +"""Filesystem capability that provides sandboxed file system access.""" + +from __future__ import annotations + +from collections.abc import Sequence +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +from pydantic_ai.capabilities import AbstractCapability +from pydantic_ai.tools import AgentDepsT +from pydantic_ai.toolsets import AgentToolset + +from pydantic_ai_harness.filesystem._toolset import FileSystemToolset + +_DEFAULT_PROTECTED: list[str] = [ + '.git/*', + '.env', + '.env.*', + '*.pem', + '*.key', + '**/secrets*', +] + + +@dataclass +class FileSystem(AbstractCapability[AgentDepsT]): + """File system access scoped to a root directory. + + All paths are resolved relative to `root_dir`. Traversal above the root + is rejected. Symlinks are resolved before authorization. + """ + + root_dir: str | Path = '.' + """Root directory for all file operations. Defaults to the current directory.""" + + allowed_patterns: Sequence[str] = field(default_factory=list[str]) + """If non-empty, only paths matching at least one glob pattern are accessible.""" + + denied_patterns: Sequence[str] = field(default_factory=list[str]) + """Paths matching any of these glob patterns are rejected.""" + + protected_patterns: Sequence[str] = field(default_factory=lambda: list(_DEFAULT_PROTECTED)) + """Paths matching these patterns are read-only (writes are rejected). + + Defaults to protecting `.git/`, `.env`, key files, and secrets. + Set to an empty list to disable protection. + """ + + max_read_lines: int = 2000 + """Maximum number of lines returned by a single `read_file` call.""" + + max_search_results: int = 1000 + """Maximum number of matches returned by `search_files`.""" + + max_find_results: int = 1000 + """Maximum number of matches returned by `find_files`.""" + + def __post_init__(self) -> None: + # Runtime validation: dataclass field annotations are advisory, not enforced. + # A config-driven caller could pass a string that would otherwise propagate. + values: dict[str, Any] = { + 'max_read_lines': self.max_read_lines, + 'max_search_results': self.max_search_results, + 'max_find_results': self.max_find_results, + } + for name, value in values.items(): + if not isinstance(value, int) or value <= 0: + raise ValueError(f'{name} must be a positive integer, got {value!r}') + + def get_toolset(self) -> AgentToolset[AgentDepsT]: + """Build and return the filesystem toolset.""" + return FileSystemToolset[AgentDepsT]( + root_dir=Path(self.root_dir), + allowed_patterns=self.allowed_patterns, + denied_patterns=self.denied_patterns, + protected_patterns=self.protected_patterns, + max_read_lines=self.max_read_lines, + max_search_results=self.max_search_results, + max_find_results=self.max_find_results, + ) diff --git a/pydantic_ai_harness/filesystem/_toolset.py b/pydantic_ai_harness/filesystem/_toolset.py new file mode 100644 index 0000000..94d3ce4 --- /dev/null +++ b/pydantic_ai_harness/filesystem/_toolset.py @@ -0,0 +1,507 @@ +"""Filesystem toolset providing sandboxed file operations.""" + +from __future__ import annotations + +import fnmatch +import functools +import hashlib +import os +import re +from collections.abc import Awaitable, Callable, Sequence +from pathlib import Path +from typing import Concatenate, ParamSpec + +from pydantic_ai.exceptions import ModelRetry +from pydantic_ai.tools import AgentDepsT +from pydantic_ai.toolsets import FunctionToolset + +_P = ParamSpec('_P') + +# Errors that mean "the model asked for something the tool couldn't do" — a +# missing file, a denied path, a stale edit. pyai only feeds `ModelRetry` back +# to the model; any other exception aborts the whole run. `_recoverable` +# converts these so the agent can correct itself and continue. +_RECOVERABLE_ERRORS = (PermissionError, FileNotFoundError, NotADirectoryError, IsADirectoryError, ValueError) + + +def _recoverable( + fn: Callable[Concatenate[FileSystemToolset, _P], Awaitable[str]], +) -> Callable[Concatenate[FileSystemToolset, _P], Awaitable[str]]: + """Surface model-correctable tool errors as `ModelRetry`.""" + + @functools.wraps(fn) + async def wrapper(self: FileSystemToolset, *args: _P.args, **kwargs: _P.kwargs) -> str: + try: + return await fn(self, *args, **kwargs) + except _RECOVERABLE_ERRORS as e: + raise ModelRetry(str(e)) from e + + return wrapper + + +def _format_lines(lines: Sequence[str], offset: int, limit: int) -> str: + """Format pre-split lines with line numbers and continuation hint.""" + total = len(lines) + + if total == 0: + return '(empty file)\n' + + if offset >= total: + raise ValueError(f'Offset {offset} exceeds file length ({total} lines).') + + selected = lines[offset : offset + limit] + numbered = [f'{i:>6}\t{line}' for i, line in enumerate(selected, start=offset + 1)] + result = ''.join(numbered) + if not result.endswith('\n'): + result += '\n' + + remaining = total - (offset + len(selected)) + if remaining > 0: + next_offset = offset + len(selected) + result += f'... ({remaining} more lines. Use offset={next_offset} to continue reading.)\n' + + return result + + +def _is_binary(data: bytes, sample_size: int = 8192) -> bool: + """Detect binary content by checking for null bytes in the sample.""" + return b'\x00' in data[:sample_size] + + +def _content_hash(content: str) -> str: + """Compute a short content hash for conflict detection.""" + return hashlib.sha256(content.encode('utf-8')).hexdigest()[:12] + + +class FileSystemToolset(FunctionToolset[AgentDepsT]): + """Toolset providing filesystem operations scoped to a root directory. + + Security model: + - All paths resolved relative to root with canonical path checks + - Symlinks resolved before authorization (prevents TOCTTOU) + - Glob-based allow/deny filtering + - Protected path patterns (e.g. `.git/`, `.env`) + - Binary file detection blocks text operations + """ + + def __init__( + self, + *, + root_dir: Path, + allowed_patterns: Sequence[str], + denied_patterns: Sequence[str], + protected_patterns: Sequence[str], + max_read_lines: int, + max_search_results: int, + max_find_results: int, + ) -> None: + super().__init__() + self._root = root_dir.resolve() + self._real_root = Path(os.path.realpath(self._root)) + self._allowed_patterns = list(allowed_patterns) + self._denied_patterns = list(denied_patterns) + self._protected_patterns = list(protected_patterns) + self._max_read_lines = max_read_lines + self._max_search_results = max_search_results + self._max_find_results = max_find_results + + self.add_function(self.read_file, name='read_file') + self.add_function(self.write_file, name='write_file') + self.add_function(self.edit_file, name='edit_file') + self.add_function(self.list_directory, name='list_directory') + self.add_function(self.search_files, name='search_files') + self.add_function(self.find_files, name='find_files') + self.add_function(self.create_directory, name='create_directory') + self.add_function(self.file_info, name='file_info') + + def _matches(self, path: str, pattern: str) -> bool: + """Glob-match a relative path, treating a leading `**/` as 'any directory, including the root'. + + `fnmatch` has no recursive `**`, so a bare `**/secrets*` would miss a + root-level `secrets.yaml` — there's no leading directory to match. + Retrying with the `**/` prefix stripped covers the zero-directory case. + """ + if fnmatch.fnmatch(path, pattern): + return True + if pattern.startswith('**/'): + return fnmatch.fnmatch(path, pattern[3:]) + return False + + def _first_matching_pattern(self, path: str, patterns: list[str]) -> str | None: + """Return the first pattern that matches path, or None.""" + return next((p for p in patterns if self._matches(path, p)), None) + + def _resolve_path(self, path: str) -> Path: + """Resolve path relative to root, rejecting traversal. + + Uses os.path.realpath for symlink resolution before checking containment. + """ + candidate = (self._root / path).resolve() + real = Path(os.path.realpath(candidate)) + if not real.is_relative_to(self._real_root): + raise PermissionError(f'Path {path!r} resolves outside the root directory.') + + return real + + def _check_access(self, path: str, *, write: bool = False, check_allowed: bool = True) -> None: + """Validate path against allow/deny/protected patterns. + + `check_allowed=False` skips the `allowed_patterns` gate. Walkers + (`list_directory`, `search_files`, `find_files`) pass it so their root + directory isn't required to match `allowed_patterns` itself — `.` or + `src` would never match a file pattern like `src/*.py`. The walk's + entries are still filtered against `allowed_patterns` per-entry via + `_is_accessible`. Denied and protected patterns continue to gate the + root. + """ + if write and self._protected_patterns: + matched = self._first_matching_pattern(path, self._protected_patterns) + if matched: + raise PermissionError(f'Path {path!r} is protected (matches {matched!r}).') + + if self._denied_patterns: + matched = self._first_matching_pattern(path, self._denied_patterns) + if matched: + raise PermissionError(f'Path {path!r} is denied by pattern {matched!r}.') + + if check_allowed and self._allowed_patterns: + if not any(self._matches(path, p) for p in self._allowed_patterns): + raise PermissionError(f'Path {path!r} does not match any allowed pattern.') + + def _is_accessible(self, path: str, *, write: bool = False) -> bool: + """Predicate form of `_check_access` for filtering recursive walkers. + + Used by `list_directory`, `search_files`, and `find_files` to skip + children that would be rejected if accessed directly. Note this only + checks the relative path against patterns; it does not resolve symlinks. + """ + if write and self._protected_patterns: + if self._first_matching_pattern(path, self._protected_patterns) is not None: + return False + if self._denied_patterns: + if self._first_matching_pattern(path, self._denied_patterns) is not None: + return False + if self._allowed_patterns and not any(self._matches(path, p) for p in self._allowed_patterns): + return False + return True + + def _relative_to_root(self, resolved: Path) -> str: + """Canonical path of a resolved location relative to the real root.""" + return str(resolved.relative_to(self._real_root)) + + def _safe_resolve(self, path: str, *, write: bool = False, check_allowed: bool = True) -> Path: + """Resolve and access-check a path in one step. + + Resolution happens first so the access check matches patterns against + the canonical path relative to the root, collapsing `.`/`..`/`//` + segments that would otherwise slip past a literal pattern (e.g. + `config/./secret.txt` evading a `config/secret.txt` deny rule). + """ + resolved = self._resolve_path(path) + self._check_access(self._relative_to_root(resolved), write=write, check_allowed=check_allowed) + return resolved + + @_recoverable + async def read_file(self, path: str, *, offset: int = 0, limit: int | None = None) -> str: + """Read a text file with line numbers. + + Args: + path: File path relative to the root directory. + offset: Zero-based line offset to start reading from. + limit: Maximum number of lines to return (default: 2000). + + Returns: + File content with line numbers, plus metadata header. + """ + if limit is None: + limit = self._max_read_lines + resolved = self._safe_resolve(path) + if not resolved.is_file(): + if resolved.is_dir(): + raise FileNotFoundError(f"'{path}' is a directory, not a file.") + raise FileNotFoundError(f'File not found: {path}') + + raw = resolved.read_bytes() + if _is_binary(raw): + size = len(raw) + return f'[Binary file: {size} bytes. Use a binary-aware tool to inspect.]' + + text = raw.decode('utf-8', errors='replace') + lines = text.splitlines(keepends=True) + content_hash = _content_hash(text) + + header = f'[{path} | {len(lines)} lines | hash:{content_hash}]\n' + return header + _format_lines(lines, offset, limit) + + @_recoverable + async def write_file(self, path: str, content: str, *, expected_hash: str | None = None) -> str: + """Create or overwrite a file with conflict detection. + + Args: + path: File path relative to the root directory. + content: The text content to write. + expected_hash: If provided, the write is rejected when the file exists + and its current hash doesn't match (optimistic concurrency). + + Returns: + Confirmation message with new hash. + """ + resolved = self._safe_resolve(path, write=True) + + # Optimistic concurrency: reject stale writes + if expected_hash is not None and resolved.is_file(): + current = resolved.read_text(encoding='utf-8') + current_hash = _content_hash(current) + if current_hash != expected_hash: + raise ValueError( + f'Conflict: file {path!r} has changed (expected hash:{expected_hash}, ' + f'got hash:{current_hash}). Re-read the file and retry.' + ) + + if not resolved.parent.exists(): + parent_rel = str(resolved.parent.relative_to(self._root)) + raise FileNotFoundError(f"Parent directory '{parent_rel}' does not exist. Use create_directory first.") + resolved.write_text(content, encoding='utf-8') + new_hash = _content_hash(content) + lines = len(content.splitlines()) + return f'Wrote {len(content)} chars ({lines} lines) to {path}. [hash:{new_hash}]' + + @_recoverable + async def edit_file(self, path: str, old_text: str, new_text: str, *, expected_hash: str | None = None) -> str: + """Edit a file by exact string replacement with conflict detection. + + The old_text must appear exactly once in the file. Include surrounding + context lines to ensure uniqueness. + + Args: + path: File path relative to the root directory. + old_text: The exact text to find (must appear exactly once). + new_text: The replacement text. + expected_hash: If provided, rejects the edit when the file's + current hash doesn't match (optimistic concurrency). + + Returns: + Summary with new hash for subsequent operations. + """ + resolved = self._safe_resolve(path, write=True) + if not resolved.is_file(): + raise FileNotFoundError(f'File not found: {path}') + + text = resolved.read_text(encoding='utf-8') + current_hash = _content_hash(text) + + # Optimistic concurrency check + if expected_hash is not None and current_hash != expected_hash: + raise ValueError( + f'Conflict: file {path!r} has changed (expected hash:{expected_hash}, ' + f'got hash:{current_hash}). Re-read the file and retry.' + ) + + count = text.count(old_text) + if count == 0: + raise ValueError(f'old_text not found in {path}.') + if count > 1: + raise ValueError( + f'old_text found {count} times in {path}. Include more surrounding context to make the match unique.' + ) + + new_content = text.replace(old_text, new_text, 1) + resolved.write_text(new_content, encoding='utf-8') + new_hash = _content_hash(new_content) + return f'Edited {path}. [hash:{new_hash}]' + + @_recoverable + async def list_directory(self, path: str = '.') -> str: + """List the contents of a directory. + + Args: + path: Directory path relative to the root directory. + + Returns: + A newline-separated listing with type indicators and sizes. + """ + # The listing root is gated by denied/protected patterns but not by + # allowed_patterns: a directory like '.' never matches a file pattern. + # Entries are filtered per-entry against allowed_patterns below. + resolved = self._safe_resolve(path, check_allowed=False) + if not resolved.is_dir(): + raise NotADirectoryError(f'Not a directory: {path}') + + entries: list[str] = [] + for entry in sorted(resolved.iterdir()): + try: + rel_path = entry.relative_to(self._real_root) + except ValueError: # pragma: no cover + continue + # Skip dotfiles and dot-directories, matching search_files and + # find_files so the three walkers agree on what exists. + if any(part.startswith('.') for part in rel_path.parts): + continue + rel = str(rel_path) + # Apply the same allow/deny/protected filtering used for direct + # access so a directory listing can't leak patterns the agent + # couldn't otherwise read or write. + if not self._is_accessible(rel, write=True): + continue + if entry.is_dir(): + entries.append(f'{rel}/') + else: + try: + size = entry.stat().st_size + except OSError: # pragma: no cover # file deleted between iterdir and stat + size = 0 + entries.append(f'{rel} ({size} bytes)') + return '\n'.join(entries) if entries else '(empty directory)' + + @_recoverable + async def search_files(self, pattern: str, *, path: str = '.', include_glob: str | None = None) -> str: + """Search file contents using a regular expression. + + Args: + pattern: Regex pattern to search for. + path: Directory to search in, relative to the root directory. + include_glob: If provided, only search files matching this glob (e.g. '*.py'). + + Returns: + Matching lines formatted as file:line_number:text. + """ + # See list_directory: the search root isn't gated by allowed_patterns; + # matched files are filtered per-entry below. + resolved = self._safe_resolve(path, check_allowed=False) + try: + compiled = re.compile(pattern) + except re.error as e: + raise ValueError(f'Invalid regex pattern: {e}') from e + + results: list[str] = [] + + if resolved.is_file(): + files = [resolved] + else: + files = sorted(resolved.rglob('*')) + + real_root = Path(os.path.realpath(self._root)) + for file_path in files: + if not file_path.is_file(): + continue + try: + rel_parts = file_path.relative_to(real_root).parts + except ValueError: # pragma: no cover + continue + if any(part.startswith('.') for part in rel_parts): + continue + rel_str = str(file_path.relative_to(real_root)) + # Apply the same allow/deny/protected filtering used for direct + # access so a recursive search can't read patterns the agent + # couldn't otherwise read. + if not self._is_accessible(rel_str, write=True): + continue + if include_glob and not fnmatch.fnmatch(rel_str, include_glob): + continue + try: + raw = file_path.read_bytes() + except OSError: # pragma: no cover + continue + if _is_binary(raw): + continue + text = raw.decode('utf-8', errors='replace') + for line_num, line in enumerate(text.splitlines(), start=1): + if compiled.search(line): + results.append(f'{rel_str}:{line_num}:{line}') + if len(results) >= self._max_search_results: + results.append(f'[... truncated at {self._max_search_results} matches]') + break + + return '\n'.join(results) if results else 'No matches found.' + + @_recoverable + async def find_files(self, pattern: str, *, path: str = '.') -> str: + """Find files by glob pattern (name matching, not content search). + + Args: + pattern: Glob pattern to match (e.g. '*.py', '**/*.json'). + path: Directory to search in, relative to the root directory. + + Returns: + Newline-separated list of matching file paths relative to root. + """ + # See list_directory: the find root isn't gated by allowed_patterns; + # matched entries are filtered per-entry below. + resolved = self._safe_resolve(path, check_allowed=False) + if not resolved.is_dir(): + raise NotADirectoryError(f'Not a directory: {path}') + + matches: list[str] = [] + real_root = Path(os.path.realpath(self._root)) + for match in sorted(resolved.glob(pattern)): + try: + rel_parts = match.relative_to(real_root).parts + except ValueError: # pragma: no cover + continue + if any(part.startswith('.') for part in rel_parts): + continue + rel = str(match.relative_to(real_root)) + # Apply the same allow/deny/protected filtering used for direct + # access so a glob find can't surface patterns the agent + # couldn't otherwise see. + if not self._is_accessible(rel, write=True): + continue + suffix = '/' if match.is_dir() else '' + matches.append(f'{rel}{suffix}') + if len(matches) >= self._max_find_results: + matches.append(f'[... truncated at {self._max_find_results} matches]') + break + + return '\n'.join(matches) if matches else 'No matches found.' + + @_recoverable + async def create_directory(self, path: str) -> str: + """Create a directory and any missing parents. + + Args: + path: Directory path relative to the root directory. + + Returns: + Confirmation message. + """ + resolved = self._safe_resolve(path, write=True) + resolved.mkdir(parents=True, exist_ok=True) + return f'Created directory: {path}' + + @_recoverable + async def file_info(self, path: str) -> str: + """Get metadata about a file or directory. + + Args: + path: File or directory path relative to the root directory. + + Returns: + Formatted metadata including size, type, and permissions. + """ + resolved = self._safe_resolve(path) + if not resolved.exists(): + raise FileNotFoundError(f'Path not found: {path}') + + # Check if the original (pre-resolve) path is a symlink + original = self._root / path + is_link = original.is_symlink() + + stat = resolved.stat() + kind = 'directory' if resolved.is_dir() else 'file' + size = stat.st_size + + parts = [f'path: {path}', f'type: {kind}', f'size: {size} bytes'] + + if resolved.is_file(): + raw = resolved.read_bytes() + is_bin = _is_binary(raw) + parts.append(f'binary: {is_bin}') + if not is_bin: + text = raw.decode('utf-8', errors='replace') + parts.append(f'lines: {len(text.splitlines())}') + parts.append(f'hash: {_content_hash(text)}') + + if is_link: + parts.append(f'symlink_target: {os.readlink(original)}') + + return '\n'.join(parts) diff --git a/pydantic_ai_harness/shell/README.md b/pydantic_ai_harness/shell/README.md new file mode 100644 index 0000000..f34c7fa --- /dev/null +++ b/pydantic_ai_harness/shell/README.md @@ -0,0 +1,129 @@ +# Shell + +Give an agent the ability to run shell commands, with allow/deny controls and +managed background processes. + +## The problem + +Agents frequently need to run a build, a test suite, a linter, or a quick +`grep`. Wiring up subprocess handling — streaming output, timeouts, truncation, +killing runaway processes, and cleaning up background jobs at the end of a run — +is fiddly boilerplate that every agent reinvents. + +## The solution + +`Shell` exposes command-execution tools rooted at a working directory, with +configurable allow/deny lists and automatic cleanup of background processes +when the agent run ends. + +```python +from pydantic_ai import Agent +from pydantic_ai_harness import Shell + +agent = Agent( + 'anthropic:claude-sonnet-4-6', + capabilities=[Shell(cwd='./workspace', allowed_commands=['ls', 'cat', 'rg'])], +) + +result = agent.run_sync('List the Python files and summarize the largest one.') +print(result.output) +``` + +## Tools + +| Tool | Purpose | +|---|---| +| `run_command` | Run a command synchronously and return labelled stdout/stderr plus exit code. Honors a per-call or default timeout. | +| `start_command` | Launch a long-running command (server, watcher) in the background; returns an ID. | +| `check_command` | Report the status and accumulated output of a background command. | +| `stop_command` | Terminate a background command and return its final output. | + +Output is labelled with `[stdout]` / `[stderr]` markers and an `[exit code: N]` +line on non-zero exit. When it exceeds `max_output_chars` the **tail** is kept +(the head is dropped), so errors, stack traces, and the `[stderr]` section — +which all land at the end — survive truncation. + +## Command controls + +| Field | Effect | +|---|---| +| `allowed_commands` | If non-empty, only these executables may run (allowlist). | +| `denied_commands` | These executables are always rejected (denylist). | +| `denied_operators` | Shell operators (e.g. `>`, `>>`, `|`) that are rejected when present. | +| `allow_interactive` | If `False` (default), commands that expect a TTY (`vi`, `sudo`, `ssh`, …) are blocked. | + +`allowed_commands` and `denied_commands` are mutually exclusive — set one, not +both. `denied_commands` defaults to a list of destructive commands (`rm`, +`rmdir`, `mkfs`, `dd`, `shutdown`, `reboot`, …); pass an empty list to disable. +The executable name is extracted with `shlex`, so arguments don't bypass the +check. + +> **These checks are best-effort, not a security boundary.** A sufficiently +> motivated agent can defeat them (e.g. `bash -c '...'`, env-var indirection). +> For hard guarantees, run the agent inside OS-level isolation — a container or +> sandbox. + +## Background processes + +`start_command` writes stdout/stderr to temp files and returns a short ID. Use +`check_command(id)` to poll and `stop_command(id)` to terminate and collect +final output. Processes are launched in their own session (`start_new_session`) +so the whole process group can be signalled — `SIGTERM`, escalating to +`SIGKILL` after a grace period. + +On run end, the toolset's `__aexit__` terminates every still-running background +process and deletes its temp files. The agent runtime enters toolsets via an +`AsyncExitStack`, so this cleanup runs whether the run succeeds or raises — an +agent that forgets to call `stop_command` won't leak processes. + +## Working directory + +By default each command runs in `cwd` and `cd` has no lasting effect. Set +`persist_cwd=True` to make `cd` sticky: the toolset appends a `pwd` sentinel to +successful commands, parses the result, and carries the new directory into +subsequent calls. Commands containing `;` skip the sentinel injection so the +`&&`-gated sentinel can't be bypassed. + +## Configuration + +```python +Shell( + cwd='.', # str | Path — working directory + allowed_commands=[], # allowlist (mutually exclusive with denied) + denied_commands=[...], # denylist (defaults to destructive commands) + denied_operators=[], # blocked shell operators + default_timeout=30.0, # seconds, per run_command + max_output_chars=50_000, # output cap returned to the model + persist_cwd=False, # make cd sticky across calls + allow_interactive=False, # allow TTY-style commands +) +``` + +## Agent spec (YAML/JSON) + +`Shell` works with Pydantic AI's +[agent spec](https://ai.pydantic.dev/agent-spec/): + +```yaml +# agent.yaml +model: anthropic:claude-sonnet-4-6 +capabilities: + - Shell: + cwd: ./workspace + allowed_commands: ['ls', 'cat', 'rg', 'pytest'] +``` + +```python +from pydantic_ai import Agent +from pydantic_ai_harness import Shell + +agent = Agent.from_file('agent.yaml', custom_capability_types=[Shell]) +``` + +Pass `custom_capability_types` so the spec loader knows how to instantiate +`Shell`. + +## Further reading + +- [Pydantic AI capabilities](https://ai.pydantic.dev/capabilities/) +- [Toolsets](https://ai.pydantic.dev/toolsets/) diff --git a/pydantic_ai_harness/shell/__init__.py b/pydantic_ai_harness/shell/__init__.py new file mode 100644 index 0000000..0a8d4be --- /dev/null +++ b/pydantic_ai_harness/shell/__init__.py @@ -0,0 +1,6 @@ +"""Shell capability: gives agents configurable command execution.""" + +from pydantic_ai_harness.shell._capability import Shell +from pydantic_ai_harness.shell._toolset import ShellToolset + +__all__ = ['Shell', 'ShellToolset'] diff --git a/pydantic_ai_harness/shell/_capability.py b/pydantic_ai_harness/shell/_capability.py new file mode 100644 index 0000000..7fa7730 --- /dev/null +++ b/pydantic_ai_harness/shell/_capability.py @@ -0,0 +1,76 @@ +"""Shell capability that provides command execution for agents.""" + +from __future__ import annotations + +from collections.abc import Sequence +from dataclasses import dataclass, field +from pathlib import Path + +from pydantic_ai.capabilities import AbstractCapability +from pydantic_ai.tools import AgentDepsT +from pydantic_ai.toolsets import AgentToolset + +from pydantic_ai_harness.shell._toolset import ShellToolset + +_DEFAULT_DENIED_COMMANDS: list[str] = [ + 'rm', + 'rmdir', + 'mkfs', + 'dd', + 'format', + 'shutdown', + 'reboot', + 'halt', + 'poweroff', + 'init', +] + + +@dataclass +class Shell(AbstractCapability[AgentDepsT]): + """Shell command execution for agents. + + Commands execute in a subprocess rooted at `cwd`. Use `allowed_commands` + or `denied_commands` to control what the agent can invoke. + """ + + cwd: str | Path = '.' + """Working directory for command execution.""" + + allowed_commands: Sequence[str] = field(default_factory=list[str]) + """If non-empty, only these command names may be executed (allowlist).""" + + denied_commands: Sequence[str] = field(default_factory=lambda: list(_DEFAULT_DENIED_COMMANDS)) + """These command names are always rejected (denylist). + + Defaults to blocking destructive commands (rm, dd, shutdown, etc.). + Set to an empty list to disable. + """ + + denied_operators: Sequence[str] = field(default_factory=list[str]) + """Shell operators that are blocked (e.g. '>', '>>', '|' for restrictive mode).""" + + default_timeout: float = 30.0 + """Default timeout in seconds for command execution.""" + + max_output_chars: int = 50_000 + """Maximum characters of output returned to the model.""" + + persist_cwd: bool = False + """If True, track cd commands and adjust the working directory for subsequent calls.""" + + allow_interactive: bool = False + """If True, allow interactive commands (vi, nano, ssh, etc.). Blocked by default.""" + + def get_toolset(self) -> AgentToolset[AgentDepsT]: + """Build and return the shell toolset.""" + return ShellToolset[AgentDepsT]( + cwd=Path(self.cwd), + allowed_commands=self.allowed_commands, + denied_commands=self.denied_commands, + denied_operators=self.denied_operators, + default_timeout=self.default_timeout, + max_output_chars=self.max_output_chars, + persist_cwd=self.persist_cwd, + allow_interactive=self.allow_interactive, + ) diff --git a/pydantic_ai_harness/shell/_toolset.py b/pydantic_ai_harness/shell/_toolset.py new file mode 100644 index 0000000..963e1f4 --- /dev/null +++ b/pydantic_ai_harness/shell/_toolset.py @@ -0,0 +1,500 @@ +"""Shell toolset — gives agents the ability to run commands.""" + +from __future__ import annotations + +import functools +import os +import re +import shlex +import signal +import subprocess +import tempfile +import uuid +from collections.abc import Awaitable, Callable, Sequence +from pathlib import Path +from typing import Any, Concatenate, ParamSpec + +import anyio +import anyio.abc +from pydantic_ai import RunContext +from pydantic_ai.exceptions import ModelRetry +from pydantic_ai.tools import AgentDepsT +from pydantic_ai.toolsets import AbstractToolset, FunctionToolset + +_IO_DRAIN_TIMEOUT: float = 2.0 +_KILL_GRACE_PERIOD: float = 2.0 + +_P = ParamSpec('_P') + + +def _recoverable( + fn: Callable[Concatenate[ShellToolset, _P], Awaitable[str]], +) -> Callable[Concatenate[ShellToolset, _P], Awaitable[str]]: + """Convert model-correctable errors into `ModelRetry`. + + pyai only feeds `ModelRetry` back to the model as a retry prompt; any other + exception propagates and aborts the whole run. A denied command is something + the model can recover from (pick an allowed one), so surface it as a retry + instead of crashing the agent. + """ + + @functools.wraps(fn) + async def wrapper(self: ShellToolset, *args: _P.args, **kwargs: _P.kwargs) -> str: + try: + return await fn(self, *args, **kwargs) + except PermissionError as e: + raise ModelRetry(str(e)) from e + + return wrapper + + +def _is_interactive_command(command: str) -> bool: + """Detect commands that typically require interactive input.""" + interactive_patterns = [ + r'^(vi|vim|nano|emacs|less|more|top|htop|man)\b', + r'^sudo\s', + r'^passwd\b', + r'^ssh\b', + r'^telnet\b', + r'^ftp\b', + ] + return any(re.match(p, command.strip()) for p in interactive_patterns) + + +class _BackgroundProcess: + """State for a background command using temp files for output.""" + + __slots__ = ('proc', 'command', 'stdout_path', 'stderr_path', 'finished', 'exit_code') + + def __init__( + self, + proc: anyio.abc.Process, + command: str, + stdout_path: str, + stderr_path: str, + ) -> None: + self.proc = proc + self.command = command + self.stdout_path = stdout_path + self.stderr_path = stderr_path + self.finished = False + self.exit_code: int | None = None + + +class ShellToolset(FunctionToolset[AgentDepsT]): + """Gives an agent the ability to execute shell commands. + + Supports synchronous execution (run_command) and background processes + (start_command / check_command / stop_command). Output is streamed, + truncated to fit model context, and labelled with stdout/stderr/exit code. + + Optionally tracks the working directory across calls so ``cd`` persists. + """ + + def __init__( + self, + *, + cwd: Path, + allowed_commands: Sequence[str], + denied_commands: Sequence[str], + denied_operators: Sequence[str], + default_timeout: float, + max_output_chars: int, + persist_cwd: bool, + allow_interactive: bool, + ) -> None: + super().__init__() + self._cwd = cwd.resolve() + # The configured starting directory, never mutated by persist_cwd, so + # `for_run` can hand each run a fresh instance rooted back here. + self._initial_cwd = self._cwd + self._allowed_commands = list(allowed_commands) + self._denied_commands = list(denied_commands) + self._denied_operators = list(denied_operators) + self._default_timeout = default_timeout + self._max_output_chars = max_output_chars + self._persist_cwd = persist_cwd + self._allow_interactive = allow_interactive + self._background: dict[str, _BackgroundProcess] = {} + + if self._allowed_commands and self._denied_commands: + raise ValueError('Specify allowed_commands or denied_commands, not both.') + + self.add_function(self.run_command, name='run_command') + self.add_function(self.start_command, name='start_command') + self.add_function(self.check_command, name='check_command') + self.add_function(self.stop_command, name='stop_command') + + async def for_run(self, ctx: RunContext[AgentDepsT]) -> AbstractToolset[AgentDepsT]: + """Return a fresh instance per run so cwd and background processes are isolated. + + `get_toolset` builds one shared instance at agent construction (see + `AbstractToolset.for_run`, which defaults to returning `self`). This + toolset holds mutable per-run state (`_cwd`, `_background`), so without + an override two concurrent runs would corrupt each other's cwd and kill + each other's background processes. + """ + return ShellToolset( + cwd=self._initial_cwd, + allowed_commands=self._allowed_commands, + denied_commands=self._denied_commands, + denied_operators=self._denied_operators, + default_timeout=self._default_timeout, + max_output_chars=self._max_output_chars, + persist_cwd=self._persist_cwd, + allow_interactive=self._allow_interactive, + ) + + async def __aexit__(self, *args: Any) -> None: + """Terminate all remaining background processes and clean up temp files.""" + for bg in self._background.values(): + if not bg.finished: + await self._kill_process_group(bg.proc) + with anyio.CancelScope(shield=True): + await bg.proc.wait() + await bg.proc.aclose() + self._cleanup_bg_files(bg) + self._background.clear() + + def _first_denied_operator(self, command: str) -> str | None: + """Return the first denied operator found in command, or None.""" + return next((op for op in self._denied_operators if op in command), None) + + def _check_command(self, command: str) -> None: + """Validate command against allow/deny lists. + + These checks are best-effort and are not a security boundary — a + sufficiently motivated agent can bypass them. Use OS-level isolation + (containers, sandboxes) for hard enforcement. + """ + if not self._allow_interactive and _is_interactive_command(command): + raise PermissionError(f'Interactive commands are not allowed. Command: {command!r}') + + matched_op = self._first_denied_operator(command) + if matched_op: + raise PermissionError(f'Shell operator {matched_op!r} is not allowed.') + + try: + tokens = shlex.split(command) + except ValueError: + return + if not tokens: + return + executable = tokens[0] + + if self._denied_commands and executable in self._denied_commands: + raise PermissionError(f'Command {executable!r} is denied.') + if self._allowed_commands and executable not in self._allowed_commands: + raise PermissionError(f'Command {executable!r} is not in the allowed list.') + + def _truncate(self, text: str) -> str: + """Truncate output to the configured cap, keeping the tail. + + The most useful output — errors, stack traces, exit info, and the + `[stderr]` section (which callers append last) — lands at the end, so + the head is dropped and the final `max_output_chars` are kept. + """ + if len(text) <= self._max_output_chars: + return text + marker = f'[... output truncated, showing last {self._max_output_chars} chars]\n' + return marker + text[-self._max_output_chars :] + + def _build_cwd_capture(self, command: str) -> tuple[str, Path | None]: + """Wrap a command to record its final working directory out-of-band. + + `pwd` is written to a private temp file whose random path the agent's + command can't address, so command output can never spoof the tracked + cwd — unlike parsing a sentinel out of stdout, where any command that + prints the sentinel string (or one using `;` to skip success-gating) + could redirect the cwd. Returns the wrapped command plus the temp-file + path, or the command unchanged and `None` when cwd tracking is off. + """ + if not self._persist_cwd: + return command, None + fd, name = tempfile.mkstemp(prefix='harness_cwd_') + os.close(fd) + wrapped = f'{command}\n__harness_ec=$?\npwd > {shlex.quote(name)}\nexit $__harness_ec' + return wrapped, Path(name) + + def _apply_captured_cwd(self, cwd_file: Path) -> None: + """Update the persistent cwd from the capture file, ignoring junk.""" + try: + recorded = cwd_file.read_text(encoding='utf-8').strip() + except OSError: # pragma: no cover + return + if not recorded: + return + candidate = Path(recorded) + if candidate.is_dir(): + self._cwd = candidate + + async def _kill_process_group(self, proc: anyio.abc.Process) -> None: + """SIGTERM the process group, escalating to SIGKILL after the grace period.""" + pid = proc.pid + try: + os.killpg(os.getpgid(pid), signal.SIGTERM) + except (ProcessLookupError, PermissionError, OSError): + return + + with anyio.move_on_after(_KILL_GRACE_PERIOD): + await proc.wait() + return + + # Still alive after grace period — hard kill + try: + os.killpg(os.getpgid(pid), signal.SIGKILL) + except (ProcessLookupError, PermissionError, OSError): + pass + + async def _drain_with_timeout( + self, + stdout_chunks: list[bytes], + stderr_chunks: list[bytes], + proc: anyio.abc.Process, + ) -> None: + """Drain remaining pipe data after kill (grandchildren may still hold the pipe).""" + + async def _drain_stdout() -> None: + if proc.stdout is None: + return + try: + async for chunk in proc.stdout: + stdout_chunks.append(chunk) + except (anyio.ClosedResourceError, anyio.BrokenResourceError): + pass + + async def _drain_stderr() -> None: + if proc.stderr is None: + return + try: + async for chunk in proc.stderr: + stderr_chunks.append(chunk) + except (anyio.ClosedResourceError, anyio.BrokenResourceError): + pass + + with anyio.move_on_after(_IO_DRAIN_TIMEOUT): + async with anyio.create_task_group() as tg: + tg.start_soon(_drain_stdout) + tg.start_soon(_drain_stderr) + + @_recoverable + async def run_command(self, command: str, *, timeout_seconds: float | None = None) -> str: + """Execute a shell command and return its output. + + Args: + command: The shell command to run. + timeout_seconds: Maximum seconds to wait (default: 30). + + Returns: + Labeled stdout/stderr output with exit code on non-zero exit. + """ + self._check_command(command) + timeout = timeout_seconds if timeout_seconds is not None else self._default_timeout + + actual_command, cwd_file = self._build_cwd_capture(command) + try: + proc = await anyio.open_process( + actual_command, + cwd=self._cwd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + start_new_session=True, + ) + stdout_chunks: list[bytes] = [] + stderr_chunks: list[bytes] = [] + try: + assert proc.stdout is not None + assert proc.stderr is not None + + async def _read_stdout() -> None: + assert proc.stdout is not None + async for chunk in proc.stdout: + stdout_chunks.append(chunk) + + async def _read_stderr() -> None: + assert proc.stderr is not None + async for chunk in proc.stderr: + stderr_chunks.append(chunk) + + with anyio.fail_after(timeout): + async with anyio.create_task_group() as tg: + tg.start_soon(_read_stdout) + tg.start_soon(_read_stderr) + await proc.wait() + except TimeoutError: + await self._kill_process_group(proc) + with anyio.CancelScope(shield=True): + await proc.wait() + await self._drain_with_timeout(stdout_chunks, stderr_chunks, proc) + return f'[Command timed out after {timeout}s]' + finally: + await proc.aclose() + + stdout = b''.join(stdout_chunks).decode('utf-8', errors='replace') + stderr = b''.join(stderr_chunks).decode('utf-8', errors='replace') + + parts: list[str] = [] + if stdout: + parts.append(f'[stdout]\n{stdout}') + if stderr: + parts.append(f'[stderr]\n{stderr}') + output = '\n'.join(parts) if parts else '(no output)' + + output = self._truncate(output) + exit_code = proc.returncode if proc.returncode is not None else 0 + + if cwd_file is not None and exit_code == 0: + self._apply_captured_cwd(cwd_file) + + if exit_code != 0: + return f'{output}\n[exit code: {exit_code}]' + return output + finally: + if cwd_file is not None: + cwd_file.unlink(missing_ok=True) + + @_recoverable + async def start_command(self, command: str) -> str: + """Start a long-running command in the background (e.g. a server or watcher). + + Callers MUST call `stop_command(command_id)` when done to terminate the + process and clean up temporary output files. + + Args: + command: The shell command to run in the background. + + Returns: + A message containing the unique command ID for later check/stop calls. + """ + self._check_command(command) + command_id = uuid.uuid4().hex[:12] + + stdout_file = tempfile.NamedTemporaryFile(mode='w+b', prefix=f'harness_{command_id}_out_', delete=False) + stderr_file = tempfile.NamedTemporaryFile(mode='w+b', prefix=f'harness_{command_id}_err_', delete=False) + + try: + proc = await anyio.open_process( + command, + cwd=self._cwd, + stdout=stdout_file, + stderr=stderr_file, + start_new_session=True, + ) + except BaseException: + stdout_file.close() + stderr_file.close() + os.unlink(stdout_file.name) + os.unlink(stderr_file.name) + raise + + stdout_file.close() + stderr_file.close() + + bg = _BackgroundProcess( + proc=proc, + command=command, + stdout_path=stdout_file.name, + stderr_path=stderr_file.name, + ) + self._background[command_id] = bg + + return f'Started background command: {command!r}\nID: {command_id}' + + def _read_bg_output(self, bg: _BackgroundProcess) -> tuple[str, str]: + """Read current output from background process temp files.""" + try: + stdout = Path(bg.stdout_path).read_text(encoding='utf-8', errors='replace') + except OSError: + stdout = '' + try: + stderr = Path(bg.stderr_path).read_text(encoding='utf-8', errors='replace') + except OSError: + stderr = '' + return stdout, stderr + + def _cleanup_bg_files(self, bg: _BackgroundProcess) -> None: + """Remove temp files for a background process.""" + try: + os.unlink(bg.stdout_path) + except OSError: + pass + try: + os.unlink(bg.stderr_path) + except OSError: + pass + + async def check_command(self, command_id: str) -> str: + """Check the status and recent output of a background command. + + Args: + command_id: The ID returned by start_command. + + Returns: + Status and recent output of the background command. + """ + bg = self._background.get(command_id) + if bg is None: + return f'[Error: unknown command ID {command_id!r}]' + + if not bg.finished and bg.proc.returncode is not None: + bg.exit_code = bg.proc.returncode + bg.finished = True + + stdout, stderr = self._read_bg_output(bg) + + status = 'finished' if bg.finished else 'running' + parts = [f'[status: {status}]'] + if bg.finished and bg.exit_code is not None: + parts.append(f'[exit code: {bg.exit_code}]') + output_sections: list[str] = [] + if stdout: + output_sections.append(f'[stdout]\n{stdout}') + if stderr: + output_sections.append(f'[stderr]\n{stderr}') + if output_sections: + parts.append(self._truncate('\n'.join(output_sections))) + else: + parts.append('(no output yet)') + + return '\n'.join(parts) + + async def stop_command(self, command_id: str) -> str: + """Stop a background command and return its final output. + + Args: + command_id: The ID returned by start_command. + + Returns: + Final output and exit status of the stopped command. + """ + bg = self._background.get(command_id) + if bg is None: + return f'[Error: unknown command ID {command_id!r}]' + + if not bg.finished: + await self._kill_process_group(bg.proc) + with anyio.CancelScope(shield=True): + await bg.proc.wait() + bg.exit_code = bg.proc.returncode + bg.finished = True + + stdout, stderr = self._read_bg_output(bg) + + self._cleanup_bg_files(bg) + del self._background[command_id] + await bg.proc.aclose() + + parts = [f'[stopped: {bg.command!r}]'] + if bg.exit_code is not None: + parts.append(f'[exit code: {bg.exit_code}]') + output_sections: list[str] = [] + if stdout: + output_sections.append(f'[stdout]\n{stdout}') + if stderr: + output_sections.append(f'[stderr]\n{stderr}') + if output_sections: + parts.append(self._truncate('\n'.join(output_sections))) + else: + parts.append('(no output)') + + return '\n'.join(parts) diff --git a/pyproject.toml b/pyproject.toml index bbc43f0..7b94dc4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -109,7 +109,7 @@ quote-style = 'single' [tool.pyright] pythonVersion = '3.10' typeCheckingMode = 'strict' -exclude = ['template', '.venv'] +exclude = ['template', '.venv', 'mutants'] executionEnvironments = [ { root = 'tests', reportPrivateUsage = false }, ] @@ -140,3 +140,15 @@ exclude_lines = [ 'assert_never', 'if TYPE_CHECKING:', ] + +[tool.mutmut] +paths_to_mutate = [ + 'pydantic_ai_harness/filesystem/_toolset.py', + 'pydantic_ai_harness/shell/_toolset.py', +] +tests_dir = ['tests/filesystem/', 'tests/shell/'] +also_copy = ['pydantic_ai_harness/', 'tests/'] +# Skip trio-parametrized tests during mutation testing — trio segfaults in +# mutmut's subprocess environment on Python 3.14 (not a code bug). +pytest_add_cli_args = ['-k', 'not trio'] +# See docs/mutation-testing.md for full results (89.7% kill rate, 60 equivalent mutants). diff --git a/scripts/run-mutmut.sh b/scripts/run-mutmut.sh new file mode 100755 index 0000000..ec8b13a --- /dev/null +++ b/scripts/run-mutmut.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash +# One-off mutation testing runner. +# +# mutmut is intentionally not a project dev dependency: it pulls in a large +# tree and is only needed when validating test quality. Install it ephemerally +# via `uv run --with` and invoke it as a subcommand. +# +# Config (paths_to_mutate, tests_dir, also_copy, pytest_add_cli_args) lives in +# [tool.mutmut] in pyproject.toml — mutmut v3 reads it from CWD by default. +# +# Usage: +# scripts/run-mutmut.sh # run all mutants +# scripts/run-mutmut.sh results # show pass/fail summary +# scripts/run-mutmut.sh show # inspect a specific mutant +# scripts/run-mutmut.sh --max-children 4 run # any mutmut flag works +# +# Pair with `make testcov` to keep coverage at 100% — surviving mutants usually +# indicate missing test cases for boundary conditions. + +set -euo pipefail + +cd "$(dirname "$0")/.." + +uv run --with "mutmut>=3.5.0" -- mutmut "$@" diff --git a/tests/filesystem/__init__.py b/tests/filesystem/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/filesystem/test_filesystem.py b/tests/filesystem/test_filesystem.py new file mode 100644 index 0000000..cae0704 --- /dev/null +++ b/tests/filesystem/test_filesystem.py @@ -0,0 +1,1122 @@ +"""Tests for the FileSystem capability and FileSystemToolset.""" + +from __future__ import annotations + +from pathlib import Path + +import pytest +from pydantic_ai import Agent +from pydantic_ai.exceptions import ModelRetry +from pydantic_ai.models.test import TestModel + +from pydantic_ai_harness.filesystem import FileSystem +from pydantic_ai_harness.filesystem._toolset import FileSystemToolset, _content_hash, _format_lines, _is_binary + + +class TestFormatLines: + def test_basic_formatting(self) -> None: + text = 'line1\nline2\nline3\n' + result = _format_lines(text.splitlines(keepends=True), 0, 10) + assert ' 1\tline1\n' in result + assert ' 2\tline2\n' in result + assert ' 3\tline3\n' in result + + def test_offset(self) -> None: + text = 'a\nb\nc\nd\ne\n' + result = _format_lines(text.splitlines(keepends=True), 2, 2) + assert ' 3\tc\n' in result + assert ' 4\td\n' in result + assert '... (1 more lines. Use offset=4 to continue reading.)' in result + + def test_offset_exceeds_length(self) -> None: + text = 'a\nb\n' + with pytest.raises(ValueError, match='Offset 5 exceeds file length'): + _format_lines(text.splitlines(keepends=True), 5, 10) + + def test_empty_file(self) -> None: + result = _format_lines([], 0, 10) + assert result == '(empty file)\n' + + def test_no_trailing_newline(self) -> None: + text = 'no newline' + result = _format_lines(text.splitlines(keepends=True), 0, 10) + assert result.endswith('\n') + + def test_continuation_hint(self) -> None: + text = '\n'.join(f'line{i}' for i in range(10)) + result = _format_lines(text.splitlines(keepends=True), 0, 3) + assert '... (7 more lines. Use offset=3 to continue reading.)' in result + + +class TestIsBinary: + def test_text_content(self) -> None: + assert _is_binary(b'hello world\n') is False + + def test_binary_content(self) -> None: + assert _is_binary(b'hello\x00world') is True + + def test_null_after_sample(self) -> None: + data = b'x' * 9000 + b'\x00' + assert _is_binary(data) is False + + def test_null_at_boundary(self) -> None: + data = b'x' * 8191 + b'\x00' + assert _is_binary(data) is True + + def test_empty(self) -> None: + assert _is_binary(b'') is False + + +class TestContentHash: + def test_deterministic(self) -> None: + assert _content_hash('hello') == _content_hash('hello') + + def test_different_content(self) -> None: + assert _content_hash('hello') != _content_hash('world') + + def test_length(self) -> None: + assert len(_content_hash('test')) == 12 + + +@pytest.fixture +def fs_root(tmp_path: Path) -> Path: + (tmp_path / 'hello.txt').write_text('Hello, world!\n') + (tmp_path / 'multi.txt').write_text('line1\nline2\nline3\nline4\nline5\n') + (tmp_path / 'subdir').mkdir() + (tmp_path / 'subdir' / 'nested.py').write_text('print("nested")\n') + (tmp_path / '.hidden').write_text('secret\n') + (tmp_path / 'binary.bin').write_bytes(b'\x00\x01\x02\x03') + (tmp_path / '.git').mkdir() + (tmp_path / '.git' / 'config').write_text('[core]\n') + (tmp_path / '.env').write_text('SECRET_KEY=abc123\n') + return tmp_path + + +@pytest.fixture +def toolset(fs_root: Path) -> FileSystemToolset[None]: + return FileSystemToolset( + root_dir=fs_root, + allowed_patterns=[], + denied_patterns=[], + protected_patterns=['.git/*', '.env', '.env.*'], + max_read_lines=2000, + max_search_results=1000, + max_find_results=1000, + ) + + +class TestPathSecurity: + async def test_traversal_with_dotdot(self, toolset: FileSystemToolset[None]) -> None: + with pytest.raises(PermissionError, match='resolves outside'): + toolset._resolve_path('../../../etc/passwd') + + async def test_traversal_absolute_path(self, toolset: FileSystemToolset[None]) -> None: + with pytest.raises(PermissionError, match='resolves outside'): + toolset._resolve_path('/etc/passwd') + + async def test_traversal_encoded(self, toolset: FileSystemToolset[None]) -> None: + with pytest.raises(PermissionError, match='resolves outside'): + toolset._resolve_path('subdir/../../..') + + async def test_symlink_escape(self, toolset: FileSystemToolset[None], fs_root: Path) -> None: + """Symlink pointing outside root is rejected.""" + target = fs_root.parent / 'symlink-escape-target' + target.write_text('escaped!\n') + try: + link = fs_root / 'escape_link' + link.symlink_to(target) + with pytest.raises(PermissionError, match='resolves outside'): + toolset._resolve_path('escape_link') + finally: + target.unlink(missing_ok=True) + + async def test_valid_path_resolves(self, toolset: FileSystemToolset[None], fs_root: Path) -> None: + result = toolset._resolve_path('hello.txt') + assert result == (fs_root / 'hello.txt').resolve() + + def test_first_matching_pattern_match(self, toolset: FileSystemToolset[None]) -> None: + result = toolset._first_matching_pattern('secret.key', ['*.txt', '*.key']) + assert result == '*.key' + + def test_first_matching_pattern_no_match(self, toolset: FileSystemToolset[None]) -> None: + result = toolset._first_matching_pattern('readme.md', ['*.txt', '*.key']) + assert result is None + + def test_first_matching_pattern_empty(self, toolset: FileSystemToolset[None]) -> None: + result = toolset._first_matching_pattern('anything.py', []) + assert result is None + + async def test_nested_path_resolves(self, toolset: FileSystemToolset[None]) -> None: + result = toolset._resolve_path('subdir/nested.py') + assert result.name == 'nested.py' + + +class TestAccessPatterns: + async def test_denied_pattern_blocks(self, fs_root: Path) -> None: + ts = FileSystemToolset( + root_dir=fs_root, + allowed_patterns=[], + denied_patterns=['*.secret'], + protected_patterns=[], + max_read_lines=2000, + max_search_results=1000, + max_find_results=1000, + ) + with pytest.raises(PermissionError, match='denied by pattern'): + ts._check_access('data.secret') + + async def test_denied_pattern_passes_non_matching(self, fs_root: Path) -> None: + ts = FileSystemToolset( + root_dir=fs_root, + allowed_patterns=[], + denied_patterns=['*.secret'], + protected_patterns=[], + max_read_lines=2000, + max_search_results=1000, + max_find_results=1000, + ) + # Path that doesn't match any denied pattern should pass + ts._check_access('data.txt') + + async def test_allowed_pattern_permits(self, fs_root: Path) -> None: + ts = FileSystemToolset( + root_dir=fs_root, + allowed_patterns=['*.py'], + denied_patterns=[], + protected_patterns=[], + max_read_lines=2000, + max_search_results=1000, + max_find_results=1000, + ) + # Should not raise for .py files + ts._check_access('test.py') + + async def test_allowed_pattern_blocks_non_matching(self, fs_root: Path) -> None: + ts = FileSystemToolset( + root_dir=fs_root, + allowed_patterns=['*.py'], + denied_patterns=[], + protected_patterns=[], + max_read_lines=2000, + max_search_results=1000, + max_find_results=1000, + ) + with pytest.raises(PermissionError, match='does not match any allowed'): + ts._check_access('data.txt') + + async def test_protected_pattern_blocks_write(self, toolset: FileSystemToolset[None]) -> None: + with pytest.raises(PermissionError, match='protected'): + toolset._check_access('.git/config', write=True) + + async def test_protected_pattern_allows_read(self, toolset: FileSystemToolset[None]) -> None: + # Should not raise for read + toolset._check_access('.git/config', write=False) + + async def test_env_file_protected(self, toolset: FileSystemToolset[None]) -> None: + with pytest.raises(PermissionError, match='protected'): + toolset._check_access('.env', write=True) + + async def test_write_non_protected_with_patterns_configured(self, toolset: FileSystemToolset[None]) -> None: + # write=True on a path that doesn't match any protected pattern should pass + toolset._check_access('hello.txt', write=True) + + async def test_access_with_no_denied_patterns(self, fs_root: Path) -> None: + ts = FileSystemToolset( + root_dir=fs_root, + allowed_patterns=[], + denied_patterns=[], + protected_patterns=[], + max_read_lines=2000, + max_search_results=1000, + max_find_results=1000, + ) + # No denied, no protected, no allowed → should pass for any path + ts._check_access('anything.txt', write=True) + + async def test_is_accessible_no_patterns(self, fs_root: Path) -> None: + ts = FileSystemToolset( + root_dir=fs_root, + allowed_patterns=[], + denied_patterns=[], + protected_patterns=[], + max_read_lines=2000, + max_search_results=1000, + max_find_results=1000, + ) + assert ts._is_accessible('anything.txt') + assert ts._is_accessible('anything.txt', write=True) + + async def test_is_accessible_protected_only_on_write(self, fs_root: Path) -> None: + ts = FileSystemToolset( + root_dir=fs_root, + allowed_patterns=[], + denied_patterns=[], + protected_patterns=['.env', '.env.*'], + max_read_lines=2000, + max_search_results=1000, + max_find_results=1000, + ) + # Reads ignore the protected list — they only block writes. + assert ts._is_accessible('.env') + assert ts._is_accessible('.env', write=True) is False + # A non-protected path passes the protected check even with write=True, + # so the walker falls through to the allowed/denied check. + assert ts._is_accessible('hello.txt', write=True) + + async def test_is_accessible_denied(self, fs_root: Path) -> None: + ts = FileSystemToolset( + root_dir=fs_root, + allowed_patterns=[], + denied_patterns=['*.secret'], + protected_patterns=[], + max_read_lines=2000, + max_search_results=1000, + max_find_results=1000, + ) + assert ts._is_accessible('visible.txt') + assert ts._is_accessible('creds.secret') is False + + async def test_is_accessible_allowed_list_excludes(self, fs_root: Path) -> None: + ts = FileSystemToolset( + root_dir=fs_root, + allowed_patterns=['*.py'], + denied_patterns=[], + protected_patterns=[], + max_read_lines=2000, + max_search_results=1000, + max_find_results=1000, + ) + assert ts._is_accessible('main.py') + assert ts._is_accessible('README.md') is False + + +class TestReadFile: + async def test_read_basic(self, toolset: FileSystemToolset[None]) -> None: + result = await toolset.read_file('hello.txt') + assert 'Hello, world!' in result + assert 'hash:' in result + assert '1 lines' in result + + async def test_read_with_offset(self, toolset: FileSystemToolset[None]) -> None: + result = await toolset.read_file('multi.txt', offset=2) + assert 'line3' in result + assert 'line1' not in result + + async def test_read_with_limit(self, toolset: FileSystemToolset[None]) -> None: + result = await toolset.read_file('multi.txt', limit=2) + assert 'line1' in result + assert 'line2' in result + assert '... (3 more lines' in result + + async def test_read_directory_raises(self, toolset: FileSystemToolset[None]) -> None: + with pytest.raises(ModelRetry, match='is a directory'): + await toolset.read_file('subdir') + + async def test_read_missing_raises(self, toolset: FileSystemToolset[None]) -> None: + with pytest.raises(ModelRetry, match='File not found'): + await toolset.read_file('nonexistent.txt') + + async def test_read_binary_file(self, toolset: FileSystemToolset[None]) -> None: + result = await toolset.read_file('binary.bin') + assert 'Binary file' in result + assert '4 bytes' in result + + async def test_read_traversal_blocked(self, toolset: FileSystemToolset[None]) -> None: + with pytest.raises(ModelRetry): + await toolset.read_file('../../../etc/passwd') + + +class TestWriteFile: + async def test_write_new_file(self, toolset: FileSystemToolset[None], fs_root: Path) -> None: + result = await toolset.write_file('new.txt', 'new content\n') + assert 'Wrote' in result + assert (fs_root / 'new.txt').read_text() == 'new content\n' + + async def test_write_nonexistent_parent_raises(self, toolset: FileSystemToolset[None]) -> None: + with pytest.raises(ModelRetry, match="Parent directory 'deep/nested' does not exist"): + await toolset.write_file('deep/nested/file.txt', 'deep\n') + + async def test_write_overwrite(self, toolset: FileSystemToolset[None], fs_root: Path) -> None: + await toolset.write_file('hello.txt', 'overwritten\n') + assert (fs_root / 'hello.txt').read_text() == 'overwritten\n' + + async def test_write_conflict_detection(self, toolset: FileSystemToolset[None], fs_root: Path) -> None: + # Get current hash + content = (fs_root / 'hello.txt').read_text() + current_hash = _content_hash(content) + + # Write with correct hash succeeds + await toolset.write_file('hello.txt', 'updated\n', expected_hash=current_hash) + assert (fs_root / 'hello.txt').read_text() == 'updated\n' + + async def test_write_conflict_rejection(self, toolset: FileSystemToolset[None], fs_root: Path) -> None: + with pytest.raises(ModelRetry, match='Conflict'): + await toolset.write_file('hello.txt', 'bad\n', expected_hash='wrong_hash_x') + + async def test_write_protected_blocked(self, toolset: FileSystemToolset[None]) -> None: + with pytest.raises(ModelRetry, match='protected'): + await toolset.write_file('.env', 'HACKED=true\n') + + async def test_write_returns_hash(self, toolset: FileSystemToolset[None]) -> None: + result = await toolset.write_file('hashed.txt', 'content\n') + assert 'hash:' in result + + +class TestEditFile: + async def test_edit_basic(self, toolset: FileSystemToolset[None], fs_root: Path) -> None: + result = await toolset.edit_file('hello.txt', 'Hello, world!', 'Hello, universe!') + assert 'Edited' in result + assert (fs_root / 'hello.txt').read_text() == 'Hello, universe!\n' + + async def test_edit_not_found_text(self, toolset: FileSystemToolset[None]) -> None: + with pytest.raises(ModelRetry, match='old_text not found'): + await toolset.edit_file('hello.txt', 'NONEXISTENT', 'replacement') + + async def test_edit_ambiguous_match(self, toolset: FileSystemToolset[None], fs_root: Path) -> None: + (fs_root / 'repeat.txt').write_text('foo bar foo\n') + with pytest.raises(ModelRetry, match='found 2 times'): + await toolset.edit_file('repeat.txt', 'foo', 'baz') + + async def test_edit_missing_file(self, toolset: FileSystemToolset[None]) -> None: + with pytest.raises(ModelRetry, match='File not found'): + await toolset.edit_file('ghost.txt', 'x', 'y') + + async def test_edit_conflict_detection(self, toolset: FileSystemToolset[None], fs_root: Path) -> None: + content = (fs_root / 'hello.txt').read_text() + current_hash = _content_hash(content) + result = await toolset.edit_file('hello.txt', 'Hello', 'Hi', expected_hash=current_hash) + assert 'hash:' in result + + async def test_edit_conflict_rejection(self, toolset: FileSystemToolset[None]) -> None: + with pytest.raises(ModelRetry, match='Conflict'): + await toolset.edit_file('hello.txt', 'Hello', 'Hi', expected_hash='stale_hash_') + + async def test_edit_protected_blocked(self, toolset: FileSystemToolset[None]) -> None: + with pytest.raises(ModelRetry, match='protected'): + await toolset.edit_file('.env', 'SECRET', 'HACKED') + + async def test_edit_returns_new_hash(self, toolset: FileSystemToolset[None]) -> None: + result = await toolset.edit_file('hello.txt', 'Hello, world!', 'Goodbye!') + assert 'hash:' in result + + +class TestListDirectory: + async def test_list_root(self, toolset: FileSystemToolset[None]) -> None: + result = await toolset.list_directory('.') + assert 'hello.txt' in result + assert 'subdir/' in result + + async def test_list_subdir(self, toolset: FileSystemToolset[None]) -> None: + result = await toolset.list_directory('subdir') + assert 'nested.py' in result + + async def test_list_not_a_dir(self, toolset: FileSystemToolset[None]) -> None: + with pytest.raises(ModelRetry): + await toolset.list_directory('hello.txt') + + async def test_list_skips_hidden(self, toolset: FileSystemToolset[None]) -> None: + # Dotfiles/dot-directories are hidden, matching find_files/search_files. + result = await toolset.list_directory('.') + assert 'hello.txt' in result + assert '.hidden' not in result + assert '.git' not in result + + async def test_list_shows_sizes(self, toolset: FileSystemToolset[None]) -> None: + result = await toolset.list_directory('.') + assert 'bytes' in result + + async def test_list_shows_dir_indicator(self, toolset: FileSystemToolset[None]) -> None: + result = await toolset.list_directory('.') + assert 'subdir/' in result + + async def test_list_empty_directory(self, toolset: FileSystemToolset[None], fs_root: Path) -> None: + (fs_root / 'empty').mkdir() + result = await toolset.list_directory('empty') + assert result == '(empty directory)' + + async def test_list_hides_protected_entries(self, fs_root: Path) -> None: + # .env is protected by the default toolset fixture; .git is hidden by + # the dotfile filter, but a directory that is itself explicitly + # protected is also hidden from listings. + (fs_root / 'visible.txt').write_text('ok\n') + ts = FileSystemToolset( + root_dir=fs_root, + allowed_patterns=[], + denied_patterns=[], + protected_patterns=['.env', '.env.*'], + max_read_lines=2000, + max_search_results=1000, + max_find_results=1000, + ) + result = await ts.list_directory('.') + assert 'visible.txt' in result + assert '.env' not in result + + async def test_list_root_allowed_patterns_filters_entries(self, fs_root: Path) -> None: + # A file-shaped allowed pattern must not make the root unlistable: '.' + # is always listed, and entries are filtered against the pattern. + (fs_root / 'keep.py').write_text('ok\n') + (fs_root / 'skip.md').write_text('ok\n') + ts = FileSystemToolset( + root_dir=fs_root, + allowed_patterns=['*.py'], + denied_patterns=[], + protected_patterns=[], + max_read_lines=2000, + max_search_results=1000, + max_find_results=1000, + ) + result = await ts.list_directory('.') + assert 'keep.py' in result + assert 'skip.md' not in result + + async def test_list_hides_denied_entries(self, fs_root: Path) -> None: + (fs_root / 'visible.txt').write_text('ok\n') + (fs_root / 'creds.secret').write_text('hunter2\n') + ts = FileSystemToolset( + root_dir=fs_root, + allowed_patterns=[], + denied_patterns=['*.secret'], + protected_patterns=[], + max_read_lines=2000, + max_search_results=1000, + max_find_results=1000, + ) + result = await ts.list_directory('.') + assert 'visible.txt' in result + assert 'creds.secret' not in result + + +class TestSearchFiles: + async def test_search_basic(self, toolset: FileSystemToolset[None]) -> None: + result = await toolset.search_files('Hello') + assert 'hello.txt:1:Hello, world!' in result + + async def test_search_regex(self, toolset: FileSystemToolset[None]) -> None: + result = await toolset.search_files(r'line\d') + assert 'multi.txt' in result + + async def test_search_no_matches(self, toolset: FileSystemToolset[None]) -> None: + result = await toolset.search_files('ZZZZNOTHERE') + assert result == 'No matches found.' + + async def test_search_skips_hidden(self, toolset: FileSystemToolset[None]) -> None: + result = await toolset.search_files('secret') + assert '.hidden' not in result + + async def test_search_skips_binary(self, toolset: FileSystemToolset[None]) -> None: + result = await toolset.search_files('.') + assert 'binary.bin' not in result + + async def test_search_invalid_regex(self, toolset: FileSystemToolset[None]) -> None: + with pytest.raises(ModelRetry, match='Invalid regex'): + await toolset.search_files('[invalid') + + async def test_search_include_glob(self, toolset: FileSystemToolset[None]) -> None: + result = await toolset.search_files('print', include_glob='*.py') + assert 'nested.py' in result + + async def test_search_include_glob_excludes(self, toolset: FileSystemToolset[None]) -> None: + result = await toolset.search_files('Hello', include_glob='*.py') + assert result == 'No matches found.' + + async def test_search_in_specific_file(self, toolset: FileSystemToolset[None]) -> None: + result = await toolset.search_files('line', path='multi.txt') + assert 'multi.txt' in result + + async def test_search_truncation(self, fs_root: Path) -> None: + # Create many matching files + for i in range(20): + (fs_root / f'match{i}.txt').write_text('findme\n' * 100) + ts = FileSystemToolset( + root_dir=fs_root, + allowed_patterns=[], + denied_patterns=[], + protected_patterns=[], + max_read_lines=2000, + max_search_results=50, + max_find_results=1000, + ) + result = await ts.search_files('findme') + assert 'truncated at 50 matches' in result + + async def test_search_skips_protected_contents(self, fs_root: Path) -> None: + # The .env file has matching content but should be filtered by the + # recursive walker before its bytes are read. + (fs_root / 'visible.txt').write_text('SECRET=matchme\n') + (fs_root / '.env').write_text('SECRET=matchme\n') + ts = FileSystemToolset( + root_dir=fs_root, + allowed_patterns=[], + denied_patterns=[], + protected_patterns=['.env', '.env.*'], + max_read_lines=2000, + max_search_results=1000, + max_find_results=1000, + ) + result = await ts.search_files('matchme') + assert 'visible.txt' in result + assert '.env' not in result + + async def test_search_skips_denied_files(self, fs_root: Path) -> None: + (fs_root / 'visible.txt').write_text('lookhere\n') + (fs_root / 'creds.secret').write_text('lookhere\n') + ts = FileSystemToolset( + root_dir=fs_root, + allowed_patterns=[], + denied_patterns=['*.secret'], + protected_patterns=[], + max_read_lines=2000, + max_search_results=1000, + max_find_results=1000, + ) + result = await ts.search_files('lookhere') + assert 'visible.txt' in result + assert 'creds.secret' not in result + + async def test_search_only_matches_allowed_files(self, fs_root: Path) -> None: + # The search root ('.') isn't required to match allowed_patterns; only + # the matched files are filtered against it per-entry. + (fs_root / 'keep.py').write_text('findme\n') + (fs_root / 'skip.md').write_text('findme\n') + ts = FileSystemToolset( + root_dir=fs_root, + allowed_patterns=['*.py'], + denied_patterns=[], + protected_patterns=[], + max_read_lines=2000, + max_search_results=1000, + max_find_results=1000, + ) + result = await ts.search_files('findme') + assert 'keep.py' in result + assert 'skip.md' not in result + + +class TestFindFiles: + async def test_find_glob(self, toolset: FileSystemToolset[None]) -> None: + result = await toolset.find_files('*.txt') + assert 'hello.txt' in result + assert 'multi.txt' in result + + async def test_find_recursive(self, toolset: FileSystemToolset[None]) -> None: + result = await toolset.find_files('**/*.py') + assert 'nested.py' in result + + async def test_find_no_matches(self, toolset: FileSystemToolset[None]) -> None: + result = await toolset.find_files('*.xyz') + assert result == 'No matches found.' + + async def test_find_skips_hidden(self, toolset: FileSystemToolset[None]) -> None: + result = await toolset.find_files('*') + assert '.hidden' not in result + assert '.git' not in result + + async def test_find_not_a_dir(self, toolset: FileSystemToolset[None]) -> None: + with pytest.raises(ModelRetry): + await toolset.find_files('*.txt', path='hello.txt') + + async def test_find_in_subdir(self, toolset: FileSystemToolset[None]) -> None: + result = await toolset.find_files('*.py', path='subdir') + assert 'nested.py' in result + + async def test_find_directories(self, toolset: FileSystemToolset[None]) -> None: + result = await toolset.find_files('sub*') + assert 'subdir/' in result + + async def test_find_truncation(self, fs_root: Path) -> None: + for i in range(20): + (fs_root / f'file{i}.dat').write_text(f'{i}\n') + ts = FileSystemToolset( + root_dir=fs_root, + allowed_patterns=[], + denied_patterns=[], + protected_patterns=[], + max_read_lines=2000, + max_search_results=1000, + max_find_results=5, + ) + result = await ts.find_files('*.dat') + assert 'truncated at 5 matches' in result + + async def test_find_hides_protected_entries(self, fs_root: Path) -> None: + (fs_root / 'visible.txt').write_text('ok\n') + (fs_root / '.env').write_text('SECRET=abc\n') + ts = FileSystemToolset( + root_dir=fs_root, + allowed_patterns=[], + denied_patterns=[], + protected_patterns=['.env', '.env.*'], + max_read_lines=2000, + max_search_results=1000, + max_find_results=1000, + ) + result = await ts.find_files('*') + assert 'visible.txt' in result + assert '.env' not in result + + async def test_find_hides_denied_entries(self, fs_root: Path) -> None: + (fs_root / 'visible.txt').write_text('ok\n') + (fs_root / 'creds.secret').write_text('hunter2\n') + ts = FileSystemToolset( + root_dir=fs_root, + allowed_patterns=[], + denied_patterns=['*.secret'], + protected_patterns=[], + max_read_lines=2000, + max_search_results=1000, + max_find_results=1000, + ) + result = await ts.find_files('*') + assert 'visible.txt' in result + assert 'creds.secret' not in result + + async def test_find_only_shows_allowed_entries(self, fs_root: Path) -> None: + # The find root ('.') isn't required to match allowed_patterns; only + # the matched entries are filtered against it per-entry. + (fs_root / 'keep.py').write_text('ok\n') + (fs_root / 'skip.md').write_text('ok\n') + ts = FileSystemToolset( + root_dir=fs_root, + allowed_patterns=['*.py'], + denied_patterns=[], + protected_patterns=[], + max_read_lines=2000, + max_search_results=1000, + max_find_results=1000, + ) + result = await ts.find_files('*') + assert 'keep.py' in result + assert 'skip.md' not in result + + +class TestCreateDirectory: + async def test_create_basic(self, toolset: FileSystemToolset[None], fs_root: Path) -> None: + result = await toolset.create_directory('newdir') + assert 'Created directory' in result + assert (fs_root / 'newdir').is_dir() + + async def test_create_nested(self, toolset: FileSystemToolset[None], fs_root: Path) -> None: + await toolset.create_directory('a/b/c') + assert (fs_root / 'a' / 'b' / 'c').is_dir() + + async def test_create_existing_ok(self, toolset: FileSystemToolset[None]) -> None: + result = await toolset.create_directory('subdir') + assert 'Created directory' in result + + async def test_create_protected_blocked(self, toolset: FileSystemToolset[None]) -> None: + with pytest.raises(ModelRetry, match='protected'): + await toolset.create_directory('.git/hooks') + + +class TestFileInfo: + async def test_info_file(self, toolset: FileSystemToolset[None]) -> None: + result = await toolset.file_info('hello.txt') + assert 'type: file' in result + assert 'size:' in result + assert 'lines:' in result + assert 'hash:' in result + assert 'binary: False' in result + + async def test_info_directory(self, toolset: FileSystemToolset[None]) -> None: + result = await toolset.file_info('subdir') + assert 'type: directory' in result + + async def test_info_binary(self, toolset: FileSystemToolset[None]) -> None: + result = await toolset.file_info('binary.bin') + assert 'binary: True' in result + assert 'lines:' not in result + + async def test_info_not_found(self, toolset: FileSystemToolset[None]) -> None: + with pytest.raises(ModelRetry, match='Path not found'): + await toolset.file_info('nonexistent') + + async def test_info_symlink(self, toolset: FileSystemToolset[None], fs_root: Path) -> None: + link = fs_root / 'link.txt' + link.symlink_to(fs_root / 'hello.txt') + result = await toolset.file_info('link.txt') + assert 'type: file' in result + assert 'symlink_target:' in result + + +class TestMutationKillers: + async def test_format_lines_offset_equals_total(self) -> None: + text = 'a\nb\n' # 2 lines + with pytest.raises(ValueError, match='Offset 2 exceeds file length'): + _format_lines(text.splitlines(keepends=True), 2, 10) + + async def test_format_lines_exact_fit_no_continuation(self) -> None: + text = 'a\nb\nc\n' # 3 lines + result = _format_lines(text.splitlines(keepends=True), 0, 3) + assert '... (' not in result + assert 'more lines' not in result + + async def test_format_lines_exact_fit_from_offset(self) -> None: + text = 'a\nb\nc\n' # 3 lines + result = _format_lines(text.splitlines(keepends=True), 1, 2) # lines 2-3, 0 remaining + assert '... (' not in result + assert 'more lines' not in result + + async def test_format_lines_one_line_remaining(self) -> None: + text = 'a\nb\nc\n' # 3 lines + result = _format_lines(text.splitlines(keepends=True), 0, 2) + assert '... (1 more lines. Use offset=2 to continue reading.)' in result + + async def test_format_lines_line_number_starts_at_one(self) -> None: + text = 'first\nsecond\n' + result = _format_lines(text.splitlines(keepends=True), 0, 10) + assert ' 1\tfirst\n' in result + assert ' 0\t' not in result + + async def test_format_lines_offset_line_numbering(self) -> None: + text = 'a\nb\nc\n' + result = _format_lines(text.splitlines(keepends=True), 1, 2) + assert ' 2\tb\n' in result + assert ' 3\tc\n' in result + + async def test_is_binary_exactly_at_sample_boundary(self) -> None: + # Null byte at position 8191 (index 8191, within first 8192 bytes) + data = b'x' * 8191 + b'\x00' + assert _is_binary(data) is True + # Null byte at position 8192 (outside the sample) + data2 = b'x' * 8192 + b'\x00' + assert _is_binary(data2) is False + + async def test_content_hash_returns_exactly_12_chars(self) -> None: + h = _content_hash('test content') + assert len(h) == 12 + # Verify it's hex characters + assert all(c in '0123456789abcdef' for c in h) + + async def test_write_file_with_hash_on_new_file(self, toolset: FileSystemToolset[None], fs_root: Path) -> None: + """When a file doesn't exist, expected_hash should be ignored and the write should succeed.""" + result = await toolset.write_file('brand_new.txt', 'new content\n', expected_hash='any_hash_val') + assert 'Wrote' in result + assert (fs_root / 'brand_new.txt').read_text() == 'new content\n' + + async def test_edit_file_single_match_succeeds(self, toolset: FileSystemToolset[None], fs_root: Path) -> None: + (fs_root / 'unique.txt').write_text('unique text here\n') + result = await toolset.edit_file('unique.txt', 'unique text', 'replaced text') + assert 'Edited' in result + assert (fs_root / 'unique.txt').read_text() == 'replaced text here\n' + + async def test_edit_file_zero_matches_raises(self, toolset: FileSystemToolset[None]) -> None: + with pytest.raises(ModelRetry, match='old_text not found'): + await toolset.edit_file('hello.txt', 'DEFINITELY NOT IN FILE', 'x') + + async def test_search_truncation_stops_after_limit(self, fs_root: Path) -> None: + # Create many files with 1 match each so truncation is per-file + for i in range(10): + (fs_root / f'searchable{i}.txt').write_text(f'match_this_{i}\n') + ts = FileSystemToolset( + root_dir=fs_root, + allowed_patterns=[], + denied_patterns=[], + protected_patterns=[], + max_read_lines=2000, + max_search_results=5, + max_find_results=1000, + ) + result = await ts.search_files('match_this') + lines = result.strip().split('\n') + # Truncation check is after each file, so 5 matches + truncation msg + # Ensure we don't get all 10 matches + match_lines = [ln for ln in lines if ln.startswith('searchable')] + assert len(match_lines) <= 5 + assert 'truncated at 5 matches' in lines[-1] + + async def test_find_truncation_stops_after_limit(self, fs_root: Path) -> None: + for i in range(10): + (fs_root / f'findme{i:02d}.dat').write_text(f'{i}\n') + ts = FileSystemToolset( + root_dir=fs_root, + allowed_patterns=[], + denied_patterns=[], + protected_patterns=[], + max_read_lines=2000, + max_search_results=1000, + max_find_results=3, + ) + result = await ts.find_files('*.dat') + lines = result.strip().split('\n') + # Should have exactly 4 lines: 3 matches + 1 truncation message + assert len(lines) == 4 + assert 'truncated at 3 matches' in lines[-1] + + async def test_read_file_default_limit_used(self, toolset: FileSystemToolset[None], fs_root: Path) -> None: + # Create file with more lines than we'd see with limit=0 + (fs_root / 'big.txt').write_text('\n'.join(f'line{i}' for i in range(100)) + '\n') + result = await toolset.read_file('big.txt') + # All 100 lines should be present since max_read_lines is 2000 + assert 'line99' in result + + async def test_list_directory_with_files_not_empty(self, toolset: FileSystemToolset[None]) -> None: + result = await toolset.list_directory('subdir') + assert result != '(empty directory)' + assert 'nested.py' in result + + async def test_search_in_file_returns_only_that_file(self, toolset: FileSystemToolset[None], fs_root: Path) -> None: + # Both files contain 'Hello' / 'hello' but searching a specific file should only return from that file + (fs_root / 'other.txt').write_text('Hello from other\n') + result = await toolset.search_files('Hello', path='hello.txt') + assert 'hello.txt' in result + assert 'other.txt' not in result + + async def test_file_info_non_binary_shows_lines_and_hash(self, toolset: FileSystemToolset[None]) -> None: + result = await toolset.file_info('hello.txt') + assert 'lines: 1' in result + assert 'hash:' in result + assert 'binary: False' in result + + async def test_file_info_binary_no_lines_no_hash(self, toolset: FileSystemToolset[None]) -> None: + result = await toolset.file_info('binary.bin') + assert 'binary: True' in result + assert 'lines:' not in result + assert 'hash:' not in result + + async def test_safe_resolve_passes_write_flag(self, toolset: FileSystemToolset[None], fs_root: Path) -> None: + # Protected patterns block writes but allow reads + (fs_root / '.env.local').write_text('SECRET=x\n') + # Read should work (write=False internally) + result = await toolset.read_file('.env.local') + assert 'SECRET=x' in result + # Write should be blocked (write=True internally) + with pytest.raises(ModelRetry, match='protected'): + await toolset.write_file('.env.local', 'HACKED\n') + + async def test_format_lines_join_separator(self) -> None: + """Verify the result doesn't contain garbage between lines.""" + text = 'a\nb\nc\n' + result = _format_lines(text.splitlines(keepends=True), 0, 3) + # Lines should be directly adjacent (no separator between them) + assert ' 1\ta\n 2\tb\n 3\tc\n' in result + + async def test_format_lines_no_trailing_newline_preserves_content(self) -> None: + text = 'no newline' + result = _format_lines(text.splitlines(keepends=True), 0, 10) + # The content must still be present + assert 'no newline' in result + assert result.endswith('\n') + + async def test_read_file_hash_is_real_hash(self, toolset: FileSystemToolset[None]) -> None: + result = await toolset.read_file('hello.txt') + # The actual hash should be a hex string, not 'None' + assert 'hash:None' not in result + # Verify the hash matches what we'd compute + expected_hash = _content_hash('Hello, world!\n') + assert f'hash:{expected_hash}' in result + + async def test_read_file_non_ascii_content(self, toolset: FileSystemToolset[None], fs_root: Path) -> None: + """With invalid UTF-8 bytes, the tool should not crash — it should use replacement chars.""" + # Write raw bytes that are invalid UTF-8 + (fs_root / 'broken_utf8.txt').write_bytes(b'hello \xff\xfe world\n') + result = await toolset.read_file('broken_utf8.txt') + # Should not crash, content should contain replacement characters + assert 'hello' in result + assert 'world' in result + + async def test_read_file_default_offset_starts_at_first_line(self, toolset: FileSystemToolset[None]) -> None: + """The first line must be included when no offset is specified.""" + result = await toolset.read_file('multi.txt') + # First line must be present (line1) + assert ' 1\tline1' in result + # Verify line numbering starts at 1 + assert ' 0\t' not in result + + async def test_toolset_tool_names(self, toolset: FileSystemToolset[None]) -> None: + """Verify tools are registered with correct names.""" + tool_names = set(toolset.tools.keys()) + assert 'read_file' in tool_names + assert 'write_file' in tool_names + assert 'edit_file' in tool_names + assert 'list_directory' in tool_names + assert 'search_files' in tool_names + assert 'find_files' in tool_names + assert 'create_directory' in tool_names + assert 'file_info' in tool_names + + async def test_write_file_output_format(self, toolset: FileSystemToolset[None], fs_root: Path) -> None: + result = await toolset.write_file('fmt.txt', 'ab\ncd\n') + # Verify specific format: chars, lines, path, hash + assert 'Wrote 6 chars (2 lines) to fmt.txt.' in result + assert 'hash:' in result + # Verify hash is a real hex hash not None + assert 'hash:None' not in result + + async def test_edit_file_output_format(self, toolset: FileSystemToolset[None], fs_root: Path) -> None: + result = await toolset.edit_file('hello.txt', 'Hello, world!', 'Hi') + assert result.startswith('Edited hello.txt.') + assert 'hash:' in result + assert 'hash:None' not in result + + def test_format_lines_no_double_trailing_newline(self) -> None: + """Text that already ends with newline must NOT get a second one appended.""" + text = 'hello\n' + result = _format_lines(text.splitlines(keepends=True), 0, 10) + # Exact match: no trailing double newline + assert result == ' 1\thello\n' + + def test_safe_resolve_write_default_is_false(self, toolset: FileSystemToolset[None], fs_root: Path) -> None: + """Protected files should be readable via _safe_resolve's default (write=False).""" + (fs_root / '.env.local').write_text('SECRET=x\n') + # _safe_resolve without write= uses default write=False → read is allowed + resolved = toolset._safe_resolve('.env.local') + assert resolved.name == '.env.local' + # But with write=True, it should raise. `_safe_resolve` is an internal + # helper, so it raises the native PermissionError; the `ModelRetry` + # conversion happens in the public tool methods that wrap it. + with pytest.raises(PermissionError, match='protected'): + toolset._safe_resolve('.env.local', write=True) + + async def test_list_directory_exact_size(self, toolset: FileSystemToolset[None]) -> None: + result = await toolset.list_directory('.') + # hello.txt has 'Hello, world!\n' = 14 bytes + assert '14 bytes' in result + + async def test_list_directory_no_garbage_separator(self, toolset: FileSystemToolset[None]) -> None: + result = await toolset.list_directory('.') + assert 'XX' not in result + + async def test_list_directory_error_message(self, toolset: FileSystemToolset[None]) -> None: + with pytest.raises(ModelRetry, match='Not a directory'): + await toolset.list_directory('hello.txt') + + async def test_find_files_error_message(self, toolset: FileSystemToolset[None]) -> None: + with pytest.raises(ModelRetry, match='Not a directory'): + await toolset.find_files('*.txt', path='hello.txt') + + async def test_find_files_no_suffix_on_files(self, toolset: FileSystemToolset[None]) -> None: + result = await toolset.find_files('*') + for line in result.splitlines(): + if not line.endswith('/'): + assert 'XXXX' not in line + + async def test_find_files_no_garbage_separator(self, toolset: FileSystemToolset[None]) -> None: + result = await toolset.find_files('*.txt') + assert 'XX' not in result + + async def test_search_files_no_garbage_separator(self, toolset: FileSystemToolset[None]) -> None: + result = await toolset.search_files(r'line\d') + assert 'XX' not in result + + async def test_file_info_exact_size(self, toolset: FileSystemToolset[None]) -> None: + result = await toolset.file_info('hello.txt') + assert '14 bytes' in result + + async def test_file_info_no_garbage_separator(self, toolset: FileSystemToolset[None]) -> None: + result = await toolset.file_info('hello.txt') + assert 'XX' not in result + + async def test_search_with_invalid_utf8_file(self, toolset: FileSystemToolset[None], fs_root: Path) -> None: + """A file with invalid UTF-8 (but no null bytes = not binary) should be searchable.""" + # Write a file with invalid UTF-8 but no null bytes (not detected as binary) + (fs_root / 'bad_encoding.txt').write_bytes(b'marker_text \xff\xfe end\n') + result = await toolset.search_files('marker_text') + # Should find the file even with broken encoding + assert 'bad_encoding.txt' in result + + async def test_search_binary_skip_does_not_stop_iteration(self, toolset: FileSystemToolset[None]) -> None: + """A binary file must be skipped, but subsequent text files must still be searched.""" + # binary.bin exists in the fixture and comes before 'hello.txt' alphabetically + result = await toolset.search_files('Hello') + # hello.txt must still be found (binary.bin didn't break the loop) + assert 'hello.txt' in result + + async def test_find_hidden_skip_does_not_stop_iteration(self, toolset: FileSystemToolset[None]) -> None: + """Hidden files must be skipped, but subsequent visible files must still appear.""" + # .hidden comes before hello.txt alphabetically — skipping must not break the loop + result = await toolset.find_files('*') + assert 'hello.txt' in result + assert 'multi.txt' in result + + +class TestFileSystemCapability: + def test_default_construction(self) -> None: + fs = FileSystem() + assert fs.root_dir == '.' + assert fs.max_read_lines == 2000 + + def test_custom_construction(self, tmp_path: Path) -> None: + fs = FileSystem( + root_dir=tmp_path, + allowed_patterns=['*.py'], + denied_patterns=['test_*'], + max_read_lines=500, + ) + assert fs.max_read_lines == 500 + + def test_get_toolset_returns_toolset(self, tmp_path: Path) -> None: + fs = FileSystem(root_dir=tmp_path) + toolset = fs.get_toolset() + assert isinstance(toolset, FileSystemToolset) + + def test_protected_defaults(self) -> None: + fs = FileSystem() + assert '.git/*' in fs.protected_patterns + assert '.env' in fs.protected_patterns + + def test_non_positive_max_read_lines_rejected(self) -> None: + with pytest.raises(ValueError, match='max_read_lines must be a positive integer'): + FileSystem(max_read_lines=0) + with pytest.raises(ValueError, match='max_read_lines must be a positive integer'): + FileSystem(max_read_lines=-1) + + def test_non_positive_max_search_results_rejected(self) -> None: + with pytest.raises(ValueError, match='max_search_results must be a positive integer'): + FileSystem(max_search_results=0) + + def test_non_positive_max_find_results_rejected(self) -> None: + with pytest.raises(ValueError, match='max_find_results must be a positive integer'): + FileSystem(max_find_results=-1) + + def test_non_integer_max_read_lines_rejected(self) -> None: + # Runtime validation: dataclass annotations are advisory, so a string + # slipped in from a config must be rejected, not propagated. + with pytest.raises(ValueError, match='max_read_lines must be a positive integer'): + FileSystem(max_read_lines='1000') # type: ignore[arg-type] + + @pytest.mark.anyio(backends=['asyncio']) + async def test_agent_integration(self, tmp_path: Path, anyio_backend: object) -> None: + if str(anyio_backend) != 'asyncio': + pytest.skip('Agent.run requires asyncio event loop') + (tmp_path / 'test.txt').write_text('hello agent\n') + model = TestModel(custom_output_text='done', call_tools=[]) + agent: Agent[None, str] = Agent(model, capabilities=[FileSystem(root_dir=tmp_path)]) + result = await agent.run('read test.txt') + assert result.output == 'done' + + +class TestPatternCanonicalization: + """Sec#3: patterns match the canonical path, and a leading `**/` also + covers the repository root.""" + + async def test_denied_pattern_not_bypassed_by_dot_segment(self, fs_root: Path) -> None: + (fs_root / 'config').mkdir() + (fs_root / 'config' / 'secret.txt').write_text('token\n') + ts = FileSystemToolset( + root_dir=fs_root, + allowed_patterns=[], + denied_patterns=['config/secret.txt'], + protected_patterns=[], + max_read_lines=2000, + max_search_results=1000, + max_find_results=1000, + ) + # A './' segment must not slip the file past its deny rule. + with pytest.raises(ModelRetry, match='denied'): + await ts.read_file('config/./secret.txt') + + async def test_root_level_secrets_hidden_from_search(self, fs_root: Path) -> None: + (fs_root / 'secrets.yaml').write_text('api: PRIVATE KEY material\n') + ts = FileSystemToolset( + root_dir=fs_root, + allowed_patterns=[], + denied_patterns=[], + protected_patterns=['**/secrets*'], + max_read_lines=2000, + max_search_results=1000, + max_find_results=1000, + ) + # `**/secrets*` must protect a root-level secrets file, not just nested ones. + result = await ts.search_files('PRIVATE KEY') + assert 'secrets.yaml' not in result diff --git a/tests/shell/__init__.py b/tests/shell/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/shell/test_shell.py b/tests/shell/test_shell.py new file mode 100644 index 0000000..785d0e7 --- /dev/null +++ b/tests/shell/test_shell.py @@ -0,0 +1,1416 @@ +"""Tests for the Shell capability and ShellToolset.""" + +from __future__ import annotations + +import os +import shlex +import sys +from pathlib import Path +from unittest.mock import MagicMock, patch + +import anyio +import pytest +from pydantic_ai import Agent, RunContext +from pydantic_ai.exceptions import ModelRetry +from pydantic_ai.models.test import TestModel +from pydantic_ai.usage import RunUsage + +from pydantic_ai_harness.shell import Shell +from pydantic_ai_harness.shell._toolset import ( + ShellToolset, + _is_interactive_command, +) + + +def _run_context() -> RunContext[None]: + """Minimal `RunContext` for invoking `for_run` directly in tests.""" + return RunContext[None]( + deps=None, + model=TestModel(), + usage=RunUsage(), + prompt=None, + messages=[], + run_step=0, + ) + + +def _parse_command_id(result: str) -> str: + assert 'ID: ' in result, f'Expected "ID: " in result: {result!r}' + return result.split('ID: ')[1].strip() + + +class TestIsInteractiveCommand: + def test_vi(self) -> None: + assert _is_interactive_command('vi file.txt') is True + + def test_vim(self) -> None: + assert _is_interactive_command('vim file.txt') is True + + def test_nano(self) -> None: + assert _is_interactive_command('nano file.txt') is True + + def test_less(self) -> None: + assert _is_interactive_command('less file.txt') is True + + def test_top(self) -> None: + assert _is_interactive_command('top') is True + + def test_sudo(self) -> None: + assert _is_interactive_command('sudo rm -rf /') is True + + def test_ssh(self) -> None: + assert _is_interactive_command('ssh host') is True + + def test_regular_command(self) -> None: + assert _is_interactive_command('ls -la') is False + + def test_echo(self) -> None: + assert _is_interactive_command('echo hello') is False + + def test_grep(self) -> None: + assert _is_interactive_command('grep pattern file') is False + + def test_emacs(self) -> None: + assert _is_interactive_command('emacs file.txt') is True + + def test_man(self) -> None: + assert _is_interactive_command('man ls') is True + + def test_htop(self) -> None: + assert _is_interactive_command('htop') is True + + def test_telnet(self) -> None: + assert _is_interactive_command('telnet localhost 80') is True + + def test_ftp(self) -> None: + assert _is_interactive_command('ftp host') is True + + def test_passwd(self) -> None: + assert _is_interactive_command('passwd') is True + + def test_more(self) -> None: + assert _is_interactive_command('more file.txt') is True + + def test_not_prefix_match(self) -> None: + assert _is_interactive_command('view file.txt') is False + assert _is_interactive_command('vishnu') is False + + def test_leading_spaces(self) -> None: + assert _is_interactive_command(' vi file.txt') is True + assert _is_interactive_command(' sudo rm') is True + + +@pytest.fixture +def shell_dir(tmp_path: Path) -> Path: + (tmp_path / 'test.txt').write_text('hello\n') + (tmp_path / 'subdir').mkdir() + (tmp_path / 'subdir' / 'nested.txt').write_text('nested\n') + return tmp_path + + +@pytest.fixture +def toolset(shell_dir: Path) -> ShellToolset[None]: + return ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=['rm', 'rmdir'], + denied_operators=[], + default_timeout=10.0, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=False, + ) + + +@pytest.fixture +def persist_toolset(shell_dir: Path) -> ShellToolset[None]: + return ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=10.0, + max_output_chars=50_000, + persist_cwd=True, + allow_interactive=False, + ) + + +class TestCommandValidation: + async def test_denied_command_blocked(self, toolset: ShellToolset[None]) -> None: + with pytest.raises(PermissionError, match="'rm' is denied"): + toolset._check_command('rm -rf /') + + async def test_allowed_command_permitted(self, shell_dir: Path) -> None: + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=['echo', 'cat'], + denied_commands=[], + denied_operators=[], + default_timeout=10.0, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=False, + ) + ts._check_command('echo hello') + ts._check_command('cat file.txt') + + async def test_allowed_blocks_non_matching(self, shell_dir: Path) -> None: + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=['echo'], + denied_commands=[], + denied_operators=[], + default_timeout=10.0, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=False, + ) + with pytest.raises(PermissionError, match='not in the allowed list'): + ts._check_command('cat file.txt') + + async def test_both_allow_and_deny_raises(self, shell_dir: Path) -> None: + with pytest.raises(ValueError, match='Specify allowed_commands or denied_commands'): + ShellToolset( + cwd=shell_dir, + allowed_commands=['echo'], + denied_commands=['rm'], + denied_operators=[], + default_timeout=10.0, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=False, + ) + + async def test_interactive_blocked_by_default(self, toolset: ShellToolset[None]) -> None: + with pytest.raises(PermissionError, match='Interactive commands'): + toolset._check_command('vim file.txt') + + async def test_interactive_allowed_when_enabled(self, shell_dir: Path) -> None: + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=10.0, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=True, + ) + ts._check_command('vim file.txt') + + async def test_denied_operator_blocked(self, shell_dir: Path) -> None: + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=[], + denied_operators=['>', '>>'], + default_timeout=10.0, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=False, + ) + with pytest.raises(PermissionError, match="'>' is not allowed"): + ts._check_command('echo hello > file.txt') + + async def test_denied_operator_passes_when_not_present(self, shell_dir: Path) -> None: + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=[], + denied_operators=['>', '>>'], + default_timeout=10.0, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=False, + ) + ts._check_command('echo hello') + + async def test_unparseable_command_allowed(self, toolset: ShellToolset[None]) -> None: + toolset._check_command("echo 'unterminated") + + async def test_empty_command_allowed(self, toolset: ShellToolset[None]) -> None: + toolset._check_command('') + + async def test_denied_operator_substring_match(self, shell_dir: Path) -> None: + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=[], + denied_operators=['>>'], + default_timeout=10.0, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=False, + ) + with pytest.raises(PermissionError, match="'>>' is not allowed"): + ts._check_command('echo hello >> file.txt') + + async def test_shlex_error_returns_early(self, shell_dir: Path) -> None: + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=['rm'], + denied_operators=[], + default_timeout=10.0, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=False, + ) + ts._check_command("echo 'unterminated") + + async def test_empty_tokens(self, shell_dir: Path) -> None: + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=['echo'], + denied_commands=[], + denied_operators=[], + default_timeout=10.0, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=False, + ) + ts._check_command('') + + def test_first_denied_operator_match(self, toolset: ShellToolset[None]) -> None: + ts = ShellToolset( + cwd=Path('/tmp'), + allowed_commands=[], + denied_commands=[], + denied_operators=['|', '>'], + default_timeout=10.0, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=False, + ) + assert ts._first_denied_operator('echo hi | cat') == '|' + + def test_first_denied_operator_no_match(self, toolset: ShellToolset[None]) -> None: + ts = ShellToolset( + cwd=Path('/tmp'), + allowed_commands=[], + denied_commands=[], + denied_operators=['|', '>'], + default_timeout=10.0, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=False, + ) + assert ts._first_denied_operator('echo hello') is None + + def test_first_denied_operator_empty_list(self, toolset: ShellToolset[None]) -> None: + assert toolset._first_denied_operator('echo hi | cat') is None + + +class TestTruncation: + def test_within_limit(self, toolset: ShellToolset[None]) -> None: + assert toolset._truncate('short') == 'short' + + def test_at_limit(self, shell_dir: Path) -> None: + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=10.0, + max_output_chars=10, + persist_cwd=False, + allow_interactive=False, + ) + result = ts._truncate('x' * 10) + assert result == 'x' * 10 + + def test_over_limit(self, shell_dir: Path) -> None: + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=10.0, + max_output_chars=10, + persist_cwd=False, + allow_interactive=False, + ) + result = ts._truncate('x' * 20) + assert result.endswith('x' * 10) + assert 'truncated, showing last 10 chars' in result + + def test_exactly_at_limit_not_truncated(self, shell_dir: Path) -> None: + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=10.0, + max_output_chars=10, + persist_cwd=False, + allow_interactive=False, + ) + result = ts._truncate('x' * 10) + assert result == 'x' * 10 + assert 'truncated' not in result + + def test_one_over_limit_truncated(self, shell_dir: Path) -> None: + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=10.0, + max_output_chars=10, + persist_cwd=False, + allow_interactive=False, + ) + result = ts._truncate('x' * 11) + assert result.endswith('x' * 10) + assert 'truncated, showing last 10 chars' in result + + def test_keeps_tail_not_head(self, shell_dir: Path) -> None: + """The tail (where errors and the [stderr] section land) is preserved.""" + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=10.0, + max_output_chars=20, + persist_cwd=False, + allow_interactive=False, + ) + text = 'HEAD' + 'x' * 100 + 'TAIL_ERROR' + result = ts._truncate(text) + assert result.endswith('TAIL_ERROR') + assert 'HEAD' not in result + assert 'truncated' in result + + def test_truncation_marker_wording(self, shell_dir: Path) -> None: + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=10.0, + max_output_chars=10, + persist_cwd=False, + allow_interactive=False, + ) + result = ts._truncate('x' * 20) + assert 'output truncated, showing last 10 chars' in result + + +class TestCwdCapture: + """The persistent-cwd mechanism records `pwd` out-of-band via a private temp + file, so command output can never spoof the tracked directory.""" + + def test_capture_disabled_returns_command_unchanged(self, toolset: ShellToolset[None]) -> None: + wrapped, cwd_file = toolset._build_cwd_capture('echo hi') + assert wrapped == 'echo hi' + assert cwd_file is None + + def test_capture_records_pwd_out_of_band(self, persist_toolset: ShellToolset[None]) -> None: + wrapped, cwd_file = persist_toolset._build_cwd_capture('echo hi') + assert cwd_file is not None + try: + # pwd is redirected to the private temp file, never echoed to stdout + assert f'pwd > {shlex.quote(str(cwd_file))}' in wrapped + assert wrapped.startswith('echo hi') + finally: + cwd_file.unlink(missing_ok=True) + + def test_apply_valid_dir_updates_cwd( + self, persist_toolset: ShellToolset[None], shell_dir: Path, tmp_path: Path + ) -> None: + capture = tmp_path / 'cwd' + capture.write_text(f'{shell_dir / "subdir"}\n') + persist_toolset._apply_captured_cwd(capture) + assert persist_toolset._cwd == shell_dir / 'subdir' + + def test_apply_empty_file_keeps_cwd(self, persist_toolset: ShellToolset[None], tmp_path: Path) -> None: + original = persist_toolset._cwd + capture = tmp_path / 'cwd' + capture.write_text('') + persist_toolset._apply_captured_cwd(capture) + assert persist_toolset._cwd == original + + def test_apply_non_dir_keeps_cwd(self, persist_toolset: ShellToolset[None], tmp_path: Path) -> None: + original = persist_toolset._cwd + capture = tmp_path / 'cwd' + capture.write_text(str(tmp_path / 'does_not_exist')) + persist_toolset._apply_captured_cwd(capture) + assert persist_toolset._cwd == original + + +class TestForRunIsolation: + """B3: `get_toolset` builds one shared instance at agent construction, so + `for_run` must hand each run a fresh copy — otherwise concurrent runs share + `_cwd`/`_background` and corrupt each other.""" + + async def test_for_run_returns_fresh_instance(self, persist_toolset: ShellToolset[None]) -> None: + run1 = await persist_toolset.for_run(_run_context()) + run2 = await persist_toolset.for_run(_run_context()) + assert run1 is not persist_toolset + assert run2 is not run1 + + async def test_persist_cwd_isolated_across_runs(self, persist_toolset: ShellToolset[None], shell_dir: Path) -> None: + run1 = await persist_toolset.for_run(_run_context()) + assert isinstance(run1, ShellToolset) + await run1.run_command('cd subdir') + assert run1._cwd == shell_dir / 'subdir' + # A second run must start back at the configured root, not inherit run1's cd. + run2 = await persist_toolset.for_run(_run_context()) + assert isinstance(run2, ShellToolset) + assert run2._cwd == shell_dir + + +class TestPersistCwdHardening: + """B4: regression tests for the old stdout-sentinel footguns — a command's + output spoofing the cwd, and `;` silently disabling tracking.""" + + async def test_cd_persists_even_with_semicolon(self, persist_toolset: ShellToolset[None]) -> None: + # The old mechanism skipped tracking whenever ';' appeared, silently + # dropping a real `cd`. The out-of-band capture records it regardless. + await persist_toolset.run_command('cd subdir ; true') + result = await persist_toolset.run_command('pwd') + assert 'subdir' in result + + async def test_output_cannot_spoof_cwd(self, persist_toolset: ShellToolset[None], shell_dir: Path) -> None: + # The old mechanism parsed cwd from stdout, so a command printing the + # sentinel string could redirect the tracked cwd with no real cd. + spoof = f'true ; echo __HARNESS_PWD__{shell_dir / "subdir"}' + await persist_toolset.run_command(spoof) + assert persist_toolset._cwd == shell_dir + + +class TestRunCommand: + async def test_basic_echo(self, toolset: ShellToolset[None]) -> None: + result = await toolset.run_command('echo hello') + assert '[stdout]' in result + assert 'hello' in result + + async def test_stderr_output(self, toolset: ShellToolset[None]) -> None: + result = await toolset.run_command('echo error >&2') + assert '[stderr]' in result + assert 'error' in result + + async def test_mixed_output(self, toolset: ShellToolset[None]) -> None: + result = await toolset.run_command('echo out && echo err >&2') + assert '[stdout]' in result + assert '[stderr]' in result + + async def test_exit_code_reported(self, toolset: ShellToolset[None]) -> None: + result = await toolset.run_command('exit 42') + assert '[exit code: 42]' in result + + async def test_exit_code_zero_not_shown(self, toolset: ShellToolset[None]) -> None: + result = await toolset.run_command('echo ok') + assert 'exit code' not in result + + async def test_no_output(self, toolset: ShellToolset[None]) -> None: + result = await toolset.run_command('true') + assert result == '(no output)' + + async def test_output_truncation(self, shell_dir: Path) -> None: + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=10.0, + max_output_chars=50, + persist_cwd=False, + allow_interactive=False, + ) + result = await ts.run_command(f'{sys.executable} -c "print(\'x\' * 200)"') + assert 'truncated, showing last 50 chars' in result + + async def test_persist_cwd(self, shell_dir: Path) -> None: + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=10.0, + max_output_chars=50_000, + persist_cwd=True, + allow_interactive=False, + ) + await ts.run_command('cd subdir') + result = await ts.run_command('pwd') + assert 'subdir' in result + + async def test_persist_cwd_only_on_success(self, shell_dir: Path) -> None: + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=10.0, + max_output_chars=50_000, + persist_cwd=True, + allow_interactive=False, + ) + original = ts._cwd + await ts.run_command('cd nonexistent_dir_xyz && false') + assert ts._cwd == original + + async def test_denied_command_in_run(self, toolset: ShellToolset[None]) -> None: + # B2: a denied command is model-correctable, so it surfaces as ModelRetry + # (which pyai feeds back to the model) rather than aborting the run. + with pytest.raises(ModelRetry, match="'rm' is denied"): + await toolset.run_command('rm -rf /') + + async def test_cwd_used(self, toolset: ShellToolset[None], shell_dir: Path) -> None: + result = await toolset.run_command('cat test.txt') + assert 'hello' in result + + async def test_multiline_output(self, toolset: ShellToolset[None]) -> None: + result = await toolset.run_command(f'{sys.executable} -c "print(\'a\\nb\\nc\\n\')"') + assert '[stdout]' in result + + async def test_timeout_reports_value(self, shell_dir: Path) -> None: + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=0.5, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=False, + ) + result = await ts.run_command('sleep 10') + assert 'timed out after 0.5s' in result + + async def test_custom_timeout_overrides_default(self, shell_dir: Path) -> None: + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=30.0, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=False, + ) + result = await ts.run_command('sleep 10', timeout_seconds=0.5) + assert 'timed out after 0.5s' in result + + async def test_persist_cwd_disabled_no_update(self, shell_dir: Path) -> None: + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=10.0, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=False, + ) + original = ts._cwd + await ts.run_command('cd subdir') + assert ts._cwd == original + + async def test_nonzero_exit_shows_code(self, toolset: ShellToolset[None]) -> None: + result = await toolset.run_command('exit 1') + assert '[exit code: 1]' in result + + async def test_stdout_stderr_separated_by_newline(self, toolset: ShellToolset[None]) -> None: + result = await toolset.run_command('echo out && echo err >&2') + assert '[stdout]\nout\n\n[stderr]\nerr' in result + + async def test_non_ascii_stdout(self, toolset: ShellToolset[None]) -> None: + result = await toolset.run_command( + f'{sys.executable} -c "import sys; sys.stdout.buffer.write(b\'hello \\xff\\xfe world\\n\')"' + ) + assert 'hello' in result + + async def test_non_ascii_stderr(self, toolset: ShellToolset[None]) -> None: + result = await toolset.run_command( + f'{sys.executable} -c "import sys; sys.stderr.buffer.write(b\'err \\xff\\xfe msg\\n\')"' + ) + assert 'err' in result + + async def test_stdout_chunk_join(self, toolset: ShellToolset[None]) -> None: + result = await toolset.run_command(f"{sys.executable} -c \"print('A' * 100 + 'B' * 100)\"") + assert 'A' * 100 + 'B' * 100 in result + + async def test_exit_code_fallback_to_zero(self, shell_dir: Path) -> None: + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=10.0, + max_output_chars=50_000, + persist_cwd=True, + allow_interactive=False, + ) + result = await ts.run_command('echo ok') + assert 'exit code' not in result + + async def test_error_message_content(self, shell_dir: Path) -> None: + with pytest.raises(ValueError, match='^Specify allowed_commands or denied_commands, not both\\.$'): + ShellToolset( + cwd=shell_dir, + allowed_commands=['echo'], + denied_commands=['rm'], + denied_operators=[], + default_timeout=10.0, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=False, + ) + + async def test_stdout_chunks_joined_cleanly(self, shell_dir: Path) -> None: + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=30.0, + max_output_chars=500_000, + persist_cwd=False, + allow_interactive=False, + ) + result = await ts.run_command("printf '%05000d\\n' $(seq 1 100)") + assert 'XXXX' not in result + + async def test_stderr_chunks_joined_cleanly(self, shell_dir: Path) -> None: + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=30.0, + max_output_chars=500_000, + persist_cwd=False, + allow_interactive=False, + ) + result = await ts.run_command("printf '%0500d\\n' $(seq 1 100) >&2") + assert 'XXXX' not in result + + async def test_persist_cwd_updates_after_cd(self, shell_dir: Path) -> None: + """CWD should update to the actual directory after a successful cd.""" + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=10.0, + max_output_chars=50_000, + persist_cwd=True, + allow_interactive=False, + ) + await ts.run_command('cd subdir') + assert ts._cwd == (shell_dir / 'subdir') + + async def test_persist_cwd_not_updated_on_failure(self, shell_dir: Path) -> None: + """CWD should not update if command fails (exit code non-zero).""" + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=10.0, + max_output_chars=50_000, + persist_cwd=True, + allow_interactive=False, + ) + original = ts._cwd + await ts.run_command('false') + assert ts._cwd == original + + +class TestProcessGroupKill: + async def test_timeout_kills_subprocess_tree(self, shell_dir: Path) -> None: + """On timeout, the entire process group should be killed.""" + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=0.5, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=False, + ) + result = await ts.run_command('bash -c "sleep 100 & sleep 100"') + assert 'timed out' in result + + async def test_timeout_with_output_before_timeout(self, shell_dir: Path) -> None: + """Output produced before timeout should still result in timeout message.""" + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=0.5, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=False, + ) + result = await ts.run_command('echo before_timeout && sleep 100') + assert 'timed out' in result + + async def test_start_new_session_used(self, shell_dir: Path) -> None: + """Verify the child is in a different process group from the parent.""" + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=10.0, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=False, + ) + parent_pgrp = os.getpgrp() + result = await ts.run_command(f'{sys.executable} -c "import os; print(os.getpgrp() != {parent_pgrp})"') + assert 'True' in result + + +class TestBackgroundCommands: + async def test_start_command_returns_id(self, shell_dir: Path) -> None: + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=10.0, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=False, + ) + result = await ts.start_command('sleep 100') + assert 'ID:' in result + assert 'Started background command' in result + command_id = _parse_command_id(result) + await ts.stop_command(command_id) + + async def test_check_unknown_id(self, toolset: ShellToolset[None]) -> None: + result = await toolset.check_command('nonexistent_id') + assert 'unknown command ID' in result + + async def test_stop_unknown_id(self, toolset: ShellToolset[None]) -> None: + result = await toolset.stop_command('nonexistent_id') + assert 'unknown command ID' in result + + async def test_start_and_stop(self, shell_dir: Path) -> None: + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=10.0, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=False, + ) + start_result = await ts.start_command('echo hello_bg') + command_id = _parse_command_id(start_result) + + await anyio.sleep(0.5) + + stop_result = await ts.stop_command(command_id) + assert 'stopped' in stop_result + assert 'hello_bg' in stop_result + + async def test_start_and_check_running(self, shell_dir: Path) -> None: + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=10.0, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=False, + ) + start_result = await ts.start_command('sleep 100') + command_id = _parse_command_id(start_result) + + check_result = await ts.check_command(command_id) + assert 'running' in check_result + + await ts.stop_command(command_id) + + async def test_start_and_check_finished(self, shell_dir: Path) -> None: + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=10.0, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=False, + ) + start_result = await ts.start_command('echo done_quick') + command_id = _parse_command_id(start_result) + + await anyio.sleep(0.5) + + check_result = await ts.check_command(command_id) + assert 'finished' in check_result + assert 'done_quick' in check_result + + await ts.stop_command(command_id) + + async def test_start_denied_command_raises(self, shell_dir: Path) -> None: + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=['rm'], + denied_operators=[], + default_timeout=10.0, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=False, + ) + with pytest.raises(ModelRetry, match="'rm' is denied"): + await ts.start_command('rm -rf /') + + async def test_stop_captures_stderr(self, shell_dir: Path) -> None: + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=10.0, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=False, + ) + start_result = await ts.start_command('echo err_bg >&2') + command_id = _parse_command_id(start_result) + + await anyio.sleep(0.5) + + stop_result = await ts.stop_command(command_id) + assert 'err_bg' in stop_result + + async def test_stop_no_output(self, shell_dir: Path) -> None: + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=10.0, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=False, + ) + start_result = await ts.start_command('true') + command_id = _parse_command_id(start_result) + + await anyio.sleep(0.5) + + stop_result = await ts.stop_command(command_id) + assert '(no output)' in stop_result + + async def test_check_no_output_yet(self, shell_dir: Path) -> None: + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=10.0, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=False, + ) + start_result = await ts.start_command('sleep 100') + command_id = _parse_command_id(start_result) + + check_result = await ts.check_command(command_id) + assert 'no output yet' in check_result + + await ts.stop_command(command_id) + + async def test_check_command_captures_stderr(self, shell_dir: Path) -> None: + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=10.0, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=False, + ) + start_result = await ts.start_command('echo err_check >&2') + command_id = _parse_command_id(start_result) + + await anyio.sleep(0.5) + + check_result = await ts.check_command(command_id) + assert '[stderr]' in check_result + assert 'err_check' in check_result + + await ts.stop_command(command_id) + + async def test_start_command_uses_cwd(self, shell_dir: Path) -> None: + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=10.0, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=False, + ) + start_result = await ts.start_command('pwd') + command_id = _parse_command_id(start_result) + + await anyio.sleep(0.5) + + stop_result = await ts.stop_command(command_id) + assert str(shell_dir) in stop_result + + async def test_stop_removes_from_registry(self, shell_dir: Path) -> None: + """After stop, the command_id should no longer be known.""" + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=10.0, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=False, + ) + start_result = await ts.start_command('true') + command_id = _parse_command_id(start_result) + + await anyio.sleep(0.5) + + await ts.stop_command(command_id) + + # Should now be unknown + check_result = await ts.check_command(command_id) + assert 'unknown command ID' in check_result + + async def test_start_command_cleans_temp_files_on_failure(self, shell_dir: Path) -> None: + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=10.0, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=False, + ) + with patch('anyio.open_process', side_effect=OSError('spawn failed')): + with pytest.raises(OSError, match='spawn failed'): + await ts.start_command('echo hi') + assert not ts._background + + async def test_aexit_terminates_background_processes(self, shell_dir: Path) -> None: + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=10.0, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=False, + ) + result = await ts.start_command('sleep 300') + command_id = _parse_command_id(result) + bg = ts._background[command_id] + stdout_path = Path(bg.stdout_path) + stderr_path = Path(bg.stderr_path) + assert stdout_path.exists() + assert stderr_path.exists() + + await ts.__aexit__(None, None, None) + + assert not ts._background + assert not stdout_path.exists() + assert not stderr_path.exists() + + async def test_aexit_noop_when_no_background(self, shell_dir: Path) -> None: + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=10.0, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=False, + ) + await ts.__aexit__(None, None, None) + assert not ts._background + + async def test_aexit_cleans_already_finished_process(self, shell_dir: Path) -> None: + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=10.0, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=False, + ) + result = await ts.start_command('echo done') + command_id = _parse_command_id(result) + await anyio.sleep(0.5) + # Mark as finished via check_command + await ts.check_command(command_id) + bg = ts._background[command_id] + assert bg.finished + + await ts.__aexit__(None, None, None) + assert not ts._background + + +class TestEdgeCases: + async def test_toolset_tool_names(self, toolset: ShellToolset[None]) -> None: + tool_names = list(toolset.tools.keys()) + assert 'run_command' in tool_names + assert 'start_command' in tool_names + assert 'check_command' in tool_names + assert 'stop_command' in tool_names + + async def test_run_command_uses_actual_cwd(self, shell_dir: Path) -> None: + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=10.0, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=False, + ) + result = await ts.run_command('pwd') + assert str(shell_dir) in result + + async def test_persist_cwd_requires_all_three_conditions(self, shell_dir: Path) -> None: + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=10.0, + max_output_chars=50_000, + persist_cwd=True, + allow_interactive=False, + ) + # Successful echo — sentinel shows same dir, cwd should remain valid + await ts.run_command('echo hi') + assert ts._cwd.is_dir() + + +class TestShellCapability: + def test_default_construction(self) -> None: + shell = Shell() + assert shell.cwd == '.' + assert shell.default_timeout == 30.0 + assert 'rm' in shell.denied_commands + + def test_custom_construction(self) -> None: + shell = Shell( + cwd='/tmp', + allowed_commands=['echo', 'cat'], + denied_commands=[], + default_timeout=60.0, + ) + assert shell.default_timeout == 60.0 + + def test_get_toolset_returns_toolset(self, tmp_path: Path) -> None: + shell = Shell(cwd=tmp_path) + toolset = shell.get_toolset() + assert isinstance(toolset, ShellToolset) + + def test_default_denied_commands(self) -> None: + shell = Shell() + assert 'rm' in shell.denied_commands + assert 'dd' in shell.denied_commands + assert 'shutdown' in shell.denied_commands + + @pytest.mark.anyio(backends=['asyncio']) + async def test_agent_integration(self, tmp_path: Path) -> None: + import sniffio + + if sniffio.current_async_library() != 'asyncio': # pragma: no cover + pytest.skip('Agent.run() requires asyncio') + model = TestModel(custom_output_text='done', call_tools=[]) + agent: Agent[None, str] = Agent(model, capabilities=[Shell(cwd=tmp_path)]) + result = await agent.run('run echo hello') + assert result.output == 'done' + + +class TestKillProcessGroupEdgeCases: + async def test_sigterm_raises_process_lookup_error(self, tmp_path: Path) -> None: + """When SIGTERM raises ProcessLookupError, method returns without SIGKILL.""" + ts = ShellToolset( + cwd=tmp_path, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=5.0, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=False, + ) + proc = MagicMock() + proc.pid = 99999 + with patch('os.killpg', side_effect=ProcessLookupError): + await ts._kill_process_group(proc) + # No exception raised, method returned early + + async def test_sigkill_escalation(self, tmp_path: Path) -> None: + """When process doesn't exit within grace period, SIGKILL is sent.""" + ts = ShellToolset( + cwd=tmp_path, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=5.0, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=False, + ) + proc = MagicMock() + proc.pid = 99999 + + # Make proc.wait() never complete (simulates process ignoring SIGTERM) + async def never_return() -> None: + await anyio.sleep(999) + + proc.wait = never_return + + import signal + + kill_calls: list[tuple[int, int]] = [] + + def fake_killpg(pgid: int, sig: int) -> None: + kill_calls.append((pgid, sig)) + + with ( + patch('os.killpg', side_effect=fake_killpg), + patch('os.getpgid', return_value=12345), + patch('pydantic_ai_harness.shell._toolset._KILL_GRACE_PERIOD', 0.01), + ): + await ts._kill_process_group(proc) + + assert len(kill_calls) == 2 + assert kill_calls[0][1] == signal.SIGTERM + assert kill_calls[1][1] == signal.SIGKILL + + async def test_sigkill_raises_process_lookup_error(self, tmp_path: Path) -> None: + """When SIGKILL raises ProcessLookupError (process exited between SIGTERM and SIGKILL).""" + ts = ShellToolset( + cwd=tmp_path, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=5.0, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=False, + ) + proc = MagicMock() + proc.pid = 99999 + + async def never_return() -> None: + await anyio.sleep(999) + + proc.wait = never_return + + import signal + + call_count = 0 + + def fake_killpg(pgid: int, sig: int) -> None: + nonlocal call_count + call_count += 1 + if sig == signal.SIGKILL: + raise ProcessLookupError + + with ( + patch('os.killpg', side_effect=fake_killpg), + patch('os.getpgid', return_value=12345), + patch('pydantic_ai_harness.shell._toolset._KILL_GRACE_PERIOD', 0.01), + ): + await ts._kill_process_group(proc) + + assert call_count == 2 + + +class TestDrainWithTimeoutEdgeCases: + async def test_stdout_closed_resource_error(self, tmp_path: Path) -> None: + """ClosedResourceError on stdout is caught silently after yielding data.""" + ts = ShellToolset( + cwd=tmp_path, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=5.0, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=False, + ) + proc = MagicMock() + + # Yield one chunk then raise ClosedResourceError + class FailingStream: + def __init__(self) -> None: + self._yielded = False + + def __aiter__(self) -> FailingStream: + return self + + async def __anext__(self) -> bytes: + if not self._yielded: + self._yielded = True + return b'partial' + raise anyio.ClosedResourceError + + proc.stdout = FailingStream() + proc.stderr = None + + stdout_chunks: list[bytes] = [] + stderr_chunks: list[bytes] = [] + await ts._drain_with_timeout(stdout_chunks, stderr_chunks, proc) + assert stdout_chunks == [b'partial'] + + async def test_stderr_broken_resource_error(self, tmp_path: Path) -> None: + """BrokenResourceError on stderr is caught silently after yielding data.""" + ts = ShellToolset( + cwd=tmp_path, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=5.0, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=False, + ) + proc = MagicMock() + proc.stdout = None + + class FailingStream: + def __init__(self) -> None: + self._yielded = False + + def __aiter__(self) -> FailingStream: + return self + + async def __anext__(self) -> bytes: + if not self._yielded: + self._yielded = True + return b'partial' + raise anyio.BrokenResourceError + + proc.stderr = FailingStream() + + stdout_chunks: list[bytes] = [] + stderr_chunks: list[bytes] = [] + await ts._drain_with_timeout(stdout_chunks, stderr_chunks, proc) + assert stderr_chunks == [b'partial'] + + +class TestReadBgOutputEdgeCases: + def test_stdout_oserror(self, tmp_path: Path) -> None: + """OSError reading stdout file returns empty string.""" + ts = ShellToolset( + cwd=tmp_path, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=5.0, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=False, + ) + bg = MagicMock() + bg.stdout_path = '/nonexistent/path/stdout' + bg.stderr_path = '/nonexistent/path/stderr' + + stdout, stderr = ts._read_bg_output(bg) + assert stdout == '' + assert stderr == '' + + def test_stderr_oserror_only(self, tmp_path: Path) -> None: + """OSError reading stderr file only, stdout succeeds.""" + ts = ShellToolset( + cwd=tmp_path, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=5.0, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=False, + ) + # Create a valid stdout file but invalid stderr path + stdout_file = tmp_path / 'stdout.txt' + stdout_file.write_text('hello') + + bg = MagicMock() + bg.stdout_path = str(stdout_file) + bg.stderr_path = '/nonexistent/path/stderr' + + stdout, stderr = ts._read_bg_output(bg) + assert stdout == 'hello' + assert stderr == '' + + +class TestCleanupBgFilesEdgeCases: + def test_unlink_oserror(self, tmp_path: Path) -> None: + """OSError on unlink is caught silently.""" + ts = ShellToolset( + cwd=tmp_path, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=5.0, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=False, + ) + bg = MagicMock() + bg.stdout_path = '/nonexistent/path/stdout' + bg.stderr_path = '/nonexistent/path/stderr' + + # Should not raise + ts._cleanup_bg_files(bg) + + +class TestStopCommandAlreadyFinished: + async def test_stop_already_finished_process(self, shell_dir: Path) -> None: + """stop_command on an already-finished process skips kill.""" + ts = ShellToolset( + cwd=shell_dir, + allowed_commands=[], + denied_commands=[], + denied_operators=[], + default_timeout=10.0, + max_output_chars=50_000, + persist_cwd=False, + allow_interactive=False, + ) + # Start a command that finishes immediately + start_result = await ts.start_command('echo done') + command_id = _parse_command_id(start_result) + + # Wait for the process to finish + await anyio.sleep(0.5) + + # Manually mark as finished with exit_code = None (simulates edge case + # where finished is True but exit_code was never captured) + bg = ts._background[command_id] + bg.finished = True + bg.exit_code = None + + # stop_command should skip the kill branch and handle None exit_code + result = await ts.stop_command(command_id) + assert '[stopped:' in result + assert '[exit code:' not in result diff --git a/tests/test_placeholder.py b/tests/test_placeholder.py index a66e1c6..971604a 100644 --- a/tests/test_placeholder.py +++ b/tests/test_placeholder.py @@ -1,5 +1,7 @@ +import inspect from pathlib import Path +import pytest from pydantic_ai import Agent from pydantic_ai.models.test import TestModel @@ -11,6 +13,25 @@ def test_import(): assert isinstance(pydantic_ai_harness.__all__, list) +def test_lazy_import_filesystem(): + from pydantic_ai_harness import FileSystem + + assert inspect.isclass(FileSystem) + assert hasattr(FileSystem, 'get_toolset') + + +def test_lazy_import_shell(): + from pydantic_ai_harness import Shell + + assert inspect.isclass(Shell) + assert hasattr(Shell, 'get_toolset') + + +def test_lazy_import_unknown(): + with pytest.raises(AttributeError, match='has no attribute'): + pydantic_ai_harness.__getattr__('Nonexistent') + + def test_test_model_fixture(test_model: TestModel): assert isinstance(test_model, TestModel) diff --git a/uv.lock b/uv.lock index 27b3bf5..1249a74 100644 --- a/uv.lock +++ b/uv.lock @@ -1,6 +1,11 @@ version = 1 revision = 3 requires-python = ">=3.10" +resolution-markers = [ + "python_full_version >= '3.14'", + "python_full_version == '3.13.*'", + "python_full_version < '3.13'", +] [options]