diff --git a/src/strands_tools/apply_patch.py b/src/strands_tools/apply_patch.py new file mode 100644 index 00000000..aaff4e5e --- /dev/null +++ b/src/strands_tools/apply_patch.py @@ -0,0 +1,255 @@ +"""Hash-anchored file edits. + +Apply one or more edits to a file by anchoring each edit on a short content +hash of the lines being replaced. If the file shifts in unrelated places, the +hash for the targeted span still matches and the edit lands. If the targeted +span itself has changed, the hash mismatches and the edit fails cleanly +instead of corrupting the file. + +This is a port of the idea from the Dirac TerminalBench writeup ("hash anchors ++ Myers diff in a single token"). The hash is short, deterministic, and easy +for an LLM to emit. + +Each patch entry is: + { + "anchor_hash": "a1b2c3d4", # 8-hex-char prefix of sha256 of `old` + "old": " return value\n", + "new": " return value + 1\n", + } + +Resolution: + 1. Compute sha256 of `old` and confirm its 8-char prefix matches anchor_hash. + This catches transcription errors before touching the file. + 2. Search for `old` as a substring in the current file content. + 3. If exactly one match exists, replace it with `new`. Otherwise the patch + entry fails: zero matches means the file changed, multiple matches + means the anchor is ambiguous and `old` needs more context. + +Patches are applied in order. If any entry fails, the file is left untouched +(write is atomic via a tempfile rename). The result lists per-entry status so +the agent can decide which entries to fix and retry. + +Usage: + from strands import Agent + from strands_tools import apply_patch + + agent = Agent(tools=[apply_patch]) + agent.tool.apply_patch( + path="/repo/src/module.py", + patches=[ + { + "anchor_hash": "a1b2c3d4", + "old": "x = 1\n", + "new": "x = 2\n", + }, + ], + ) +""" + +import hashlib +import os +import tempfile +from os.path import expanduser +from typing import Any, Dict, List + +from strands import tool + +ANCHOR_LENGTH = 8 + + +def anchor_hash(text: str) -> str: + """Compute the short hex anchor for a span of text. + + Public so callers can pre-compute anchors when crafting patches without + invoking the tool. + """ + return hashlib.sha256(text.encode("utf-8")).hexdigest()[:ANCHOR_LENGTH] + + +def _apply_one(content: str, patch: Dict[str, Any], index: int) -> Dict[str, Any]: + if not isinstance(patch, dict): + return {"index": index, "status": "error", "error": "Patch entry must be an object."} + + old = patch.get("old") + new = patch.get("new") + declared = patch.get("anchor_hash") + + if not isinstance(old, str) or not isinstance(new, str): + return { + "index": index, + "status": "error", + "error": "Patch entry requires string fields 'old' and 'new'.", + } + if not isinstance(declared, str): + return { + "index": index, + "status": "error", + "error": "Patch entry requires a string 'anchor_hash'.", + } + + expected = anchor_hash(old) + if declared.lower() != expected: + return { + "index": index, + "status": "anchor_mismatch", + "expected_anchor": expected, + "received_anchor": declared, + "error": ("anchor_hash does not match sha256(old)[:8]. The 'old' text was likely transcribed incorrectly."), + } + + occurrences = content.count(old) + if occurrences == 0: + return { + "index": index, + "status": "not_found", + "error": "'old' was not found in the file. The targeted region may have changed.", + } + if occurrences > 1: + return { + "index": index, + "status": "ambiguous", + "matches": occurrences, + "error": "'old' matched more than one location. Add surrounding context to disambiguate.", + } + + return { + "index": index, + "status": "success", + "anchor_hash": expected, + "applied_content": content.replace(old, new, 1), + } + + +def _atomic_write(path: str, content: str) -> None: + directory = os.path.dirname(path) or "." + fd, tmp_path = tempfile.mkstemp(prefix=".apply_patch_", dir=directory) + try: + with os.fdopen(fd, "w", encoding="utf-8") as fh: + fh.write(content) + os.replace(tmp_path, path) + except Exception: + if os.path.exists(tmp_path): + os.remove(tmp_path) + raise + + +@tool +def apply_patch(path: str, patches: List[Dict[str, Any]]) -> Dict[str, Any]: + """Apply one or more hash-anchored edits to a file. + + Each patch carries an `anchor_hash` (the 8-character sha256 prefix of its + `old` text), an `old` span to replace, and the `new` span to substitute. + The edit lands only if the anchor verifies and `old` appears exactly once + in the current file. + + All patches must succeed for the file to be written. On any failure the + file is left unchanged and the failing entries are reported in the result. + + Args: + path: Absolute or user-relative path to the file to edit. Must exist. + patches: Ordered list of patch entries. Each entry must have keys + `anchor_hash`, `old`, and `new`, all strings. + + Returns: + ToolResult dict. The JSON content block carries: + - path: The expanded path. + - applied: count of entries that applied. + - failed: count of entries that did not. + - results: per-entry detail. Failed entries carry status values + "anchor_mismatch", "not_found", "ambiguous", or "error". + - error: present only on overall failure. + + Examples: + >>> apply_patch( + ... path="/repo/file.py", + ... patches=[ + ... {"anchor_hash": "a1b2c3d4", "old": "x=1\\n", "new": "x=2\\n"}, + ... ], + ... ) + """ + expanded = expanduser(path) + + def _wrap(payload: Dict[str, Any], status: str) -> Dict[str, Any]: + applied = payload.get("applied", 0) + failed = payload.get("failed", 0) + if "error" in payload: + text = f"apply_patch: {payload['error']}" + else: + text = f"apply_patch: applied {applied}, failed {failed} on {expanded}" + return { + "status": status, + "content": [{"text": text}, {"json": payload}], + } + + if not isinstance(patches, list) or not patches: + return _wrap( + { + "path": expanded, + "error": "'patches' must be a non-empty list.", + "applied": 0, + "failed": 0, + "results": [], + }, + "error", + ) + + try: + with open(expanded, "r", encoding="utf-8") as fh: + content = fh.read() + except OSError as exc: + return _wrap( + { + "path": expanded, + "error": f"Could not read file: {exc}", + "applied": 0, + "failed": len(patches), + "results": [], + }, + "error", + ) + + results: List[Dict[str, Any]] = [] + working = content + + for index, patch in enumerate(patches): + outcome = _apply_one(working, patch, index) + if outcome["status"] == "success": + working = outcome.pop("applied_content") + results.append(outcome) + + failed = [r for r in results if r["status"] != "success"] + if failed: + return _wrap( + { + "path": expanded, + "applied": 0, + "failed": len(failed), + "results": results, + "error": "One or more patch entries failed; file was not modified.", + }, + "error", + ) + + try: + _atomic_write(expanded, working) + except OSError as exc: + return _wrap( + { + "path": expanded, + "error": f"Could not write file: {exc}", + "applied": 0, + "failed": len(patches), + "results": results, + }, + "error", + ) + + return _wrap( + { + "path": expanded, + "applied": len(results), + "failed": 0, + "results": results, + }, + "success", + ) diff --git a/src/strands_tools/ast_context.py b/src/strands_tools/ast_context.py new file mode 100644 index 00000000..54420417 --- /dev/null +++ b/src/strands_tools/ast_context.py @@ -0,0 +1,246 @@ +"""AST-based context tool for Python source files. + +Returns a structural outline of a Python file: top-level imports, module-level +assignments, classes (with their methods), and free functions, each tagged with +a line range. Lets an agent fetch just the shape of a file without reading the +full body. + +Background: + Coding agents commonly burn context by reading whole files when they only + need to know what is defined where. The Dirac TerminalBench writeup calls + this out as `EXCESSIVE_FILE_READS`. An outline is one cheap call that + answers "what is in this file" so the agent can decide whether to read + further. + +Usage with Strands Agent: + from strands import Agent + from strands_tools import ast_context + + agent = Agent(tools=[ast_context]) + agent.tool.ast_context(path="/path/to/module.py") +""" + +import ast +from os.path import expanduser +from typing import Any, Dict, List, Optional + +from strands import tool + +_MAX_FILE_BYTES = 5 * 1024 * 1024 + + +def _signature(node: ast.AST) -> str: + """Render a one-line signature for a function or async function definition.""" + if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + return "" + + args = node.args + posonly = list(args.posonlyargs) + pos_args = list(args.args) + pos_count = len(posonly) + len(pos_args) + pos_defaults: List[Optional[str]] = [None] * (pos_count - len(args.defaults)) + [ + ast.unparse(d) for d in args.defaults + ] + kw_defaults: List[Optional[str]] = [ast.unparse(d) if d is not None else None for d in args.kw_defaults] + + parts: List[str] = [] + + pos_index = 0 + for a in posonly: + default = pos_defaults[pos_index] + parts.append(f"{a.arg}={default}" if default is not None else a.arg) + pos_index += 1 + if posonly: + parts.append("/") + + for a in pos_args: + default = pos_defaults[pos_index] + parts.append(f"{a.arg}={default}" if default is not None else a.arg) + pos_index += 1 + + if args.vararg is not None: + parts.append(f"*{args.vararg.arg}") + elif args.kwonlyargs: + parts.append("*") + + for a, default in zip(args.kwonlyargs, kw_defaults, strict=False): + parts.append(f"{a.arg}={default}" if default is not None else a.arg) + + if args.kwarg is not None: + parts.append(f"**{args.kwarg.arg}") + + prefix = "async def " if isinstance(node, ast.AsyncFunctionDef) else "def " + return f"{prefix}{node.name}({', '.join(parts)})" + + +def _line_range(node: ast.AST) -> List[int]: + start = getattr(node, "lineno", 0) + end = getattr(node, "end_lineno", start) or start + return [start, end] + + +def _outline_function(node: ast.AST) -> Dict[str, Any]: + return { + "kind": "async_function" if isinstance(node, ast.AsyncFunctionDef) else "function", + "name": getattr(node, "name", ""), + "signature": _signature(node), + "lines": _line_range(node), + "decorators": [ast.unparse(d) for d in getattr(node, "decorator_list", [])], + } + + +def _outline_class(node: ast.ClassDef) -> Dict[str, Any]: + methods: List[Dict[str, Any]] = [] + for child in node.body: + if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)): + methods.append(_outline_function(child)) + + return { + "kind": "class", + "name": node.name, + "lines": _line_range(node), + "bases": [ast.unparse(b) for b in node.bases], + "decorators": [ast.unparse(d) for d in node.decorator_list], + "methods": methods, + } + + +def _outline_imports(tree: ast.Module) -> List[Dict[str, Any]]: + imports: List[Dict[str, Any]] = [] + for node in tree.body: + if isinstance(node, ast.Import): + for alias in node.names: + imports.append( + { + "kind": "import", + "module": alias.name, + "asname": alias.asname, + "line": node.lineno, + } + ) + elif isinstance(node, ast.ImportFrom): + for alias in node.names: + imports.append( + { + "kind": "from_import", + "module": node.module, + "name": alias.name, + "asname": alias.asname, + "level": node.level, + "line": node.lineno, + } + ) + return imports + + +def _outline_assignments(tree: ast.Module) -> List[Dict[str, Any]]: + assignments: List[Dict[str, Any]] = [] + for node in tree.body: + if isinstance(node, ast.Assign): + for target in node.targets: + if isinstance(target, ast.Name): + assignments.append({"name": target.id, "line": node.lineno}) + elif isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name): + assignments.append({"name": node.target.id, "line": node.lineno}) + return assignments + + +@tool +def ast_context(path: str, include_imports: bool = True, include_assignments: bool = True) -> Dict[str, Any]: + """Return a structural outline of a Python source file. + + Parses the file with the standard library `ast` module and returns its + top-level shape: imports, module-level assignments, classes (with their + methods), and free functions. Each entry carries a line range so the agent + can follow up with a targeted read of just the relevant span. + + Args: + path: Absolute or user-relative path to a `.py` file. + include_imports: If False, omit the imports section. + include_assignments: If False, omit the module-level assignments section. + + Returns: + ToolResult dict. The JSON content block carries the structured outline: + - path: The expanded path that was read. + - docstring: The module docstring, if any. + - imports: list of `{kind, module, ...}` entries (when included). + - assignments: list of `{name, line}` entries (when included). + - classes: list of class outlines, each with name, line range, + bases, decorators, and method outlines. + - functions: list of free-function outlines. + - error: present only on error, with a human-readable message. + + Raises: + Does not raise. Errors are returned in the `error` field. + + Examples: + >>> ast_context(path="/repo/src/module.py") + {"status": "success", "path": "/repo/src/module.py", "imports": [...], + "classes": [...], "functions": [...]} + """ + expanded = expanduser(path) + + def _error(msg: str) -> Dict[str, Any]: + return { + "status": "error", + "content": [ + {"text": f"ast_context error: {msg}"}, + {"json": {"path": expanded, "error": msg}}, + ], + } + + try: + with open(expanded, "rb") as fh: + raw = fh.read(_MAX_FILE_BYTES + 1) + except OSError as exc: + return _error(f"Could not read file: {exc}") + + if len(raw) > _MAX_FILE_BYTES: + return _error(f"File exceeds {_MAX_FILE_BYTES} bytes; refusing to parse.") + + try: + source = raw.decode("utf-8") + except UnicodeDecodeError as exc: + return _error(f"File is not valid UTF-8: {exc}") + + try: + tree = ast.parse(source, filename=expanded) + except SyntaxError as exc: + return _error(f"SyntaxError at line {exc.lineno}: {exc.msg}") + + classes: List[Dict[str, Any]] = [] + functions: List[Dict[str, Any]] = [] + for node in tree.body: + if isinstance(node, ast.ClassDef): + classes.append(_outline_class(node)) + elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + functions.append(_outline_function(node)) + + docstring: Optional[str] = ast.get_docstring(tree) + + payload: Dict[str, Any] = { + "path": expanded, + "docstring": docstring, + "classes": classes, + "functions": functions, + } + if include_imports: + payload["imports"] = _outline_imports(tree) + if include_assignments: + payload["assignments"] = _outline_assignments(tree) + + summary_lines = [f"Outline of {expanded}"] + if include_imports: + summary_lines.append(f" imports: {len(payload['imports'])}") + if include_assignments: + summary_lines.append(f" assignments: {len(payload['assignments'])}") + summary_lines.append(f" classes: {len(classes)}") + summary_lines.append(f" functions: {len(functions)}") + + return { + "status": "success", + "content": [ + {"text": "\n".join(summary_lines)}, + {"json": payload}, + ], + } diff --git a/tests/test_apply_patch.py b/tests/test_apply_patch.py new file mode 100644 index 00000000..5c0acb8f --- /dev/null +++ b/tests/test_apply_patch.py @@ -0,0 +1,133 @@ +"""Tests for the apply_patch tool.""" + +import pytest +from strands import Agent + +from strands_tools import apply_patch +from strands_tools.apply_patch import anchor_hash + + +@pytest.fixture +def agent(): + return Agent(tools=[apply_patch]) + + +@pytest.fixture +def file_with_text(tmp_path): + path = tmp_path / "file.py" + path.write_text("alpha\nbeta\ngamma\n") + return str(path) + + +def _payload(tool_response): + for block in tool_response["content"]: + if "json" in block: + return block["json"] + raise AssertionError("No JSON content block in tool response") + + +def _read(path): + with open(path, "r", encoding="utf-8") as fh: + return fh.read() + + +def test_single_patch_applies(agent, file_with_text): + old = "beta\n" + response = agent.tool.apply_patch( + path=file_with_text, + patches=[{"anchor_hash": anchor_hash(old), "old": old, "new": "BETA\n"}], + ) + + assert response["status"] == "success" + assert _payload(response)["applied"] == 1 + assert _read(file_with_text) == "alpha\nBETA\ngamma\n" + + +def test_multiple_patches_applied_in_order(agent, file_with_text): + p1_old, p1_new = "alpha\n", "ALPHA\n" + p2_old, p2_new = "gamma\n", "GAMMA\n" + + response = agent.tool.apply_patch( + path=file_with_text, + patches=[ + {"anchor_hash": anchor_hash(p1_old), "old": p1_old, "new": p1_new}, + {"anchor_hash": anchor_hash(p2_old), "old": p2_old, "new": p2_new}, + ], + ) + + assert response["status"] == "success" + assert _payload(response)["applied"] == 2 + assert _read(file_with_text) == "ALPHA\nbeta\nGAMMA\n" + + +def test_anchor_mismatch_leaves_file_untouched(agent, file_with_text): + response = agent.tool.apply_patch( + path=file_with_text, + patches=[{"anchor_hash": "deadbeef", "old": "beta\n", "new": "BETA\n"}], + ) + + assert response["status"] == "error" + assert _payload(response)["results"][0]["status"] == "anchor_mismatch" + assert _read(file_with_text) == "alpha\nbeta\ngamma\n" + + +def test_not_found_leaves_file_untouched(agent, file_with_text): + missing = "delta\n" + response = agent.tool.apply_patch( + path=file_with_text, + patches=[{"anchor_hash": anchor_hash(missing), "old": missing, "new": "DELTA\n"}], + ) + + assert response["status"] == "error" + assert _payload(response)["results"][0]["status"] == "not_found" + assert _read(file_with_text) == "alpha\nbeta\ngamma\n" + + +def test_ambiguous_match_leaves_file_untouched(agent, tmp_path): + path = tmp_path / "f.py" + path.write_text("x\nx\ny\n") + + response = agent.tool.apply_patch( + path=str(path), + patches=[{"anchor_hash": anchor_hash("x\n"), "old": "x\n", "new": "X\n"}], + ) + + assert response["status"] == "error" + payload = _payload(response) + assert payload["results"][0]["status"] == "ambiguous" + assert payload["results"][0]["matches"] == 2 + assert _read(str(path)) == "x\nx\ny\n" + + +def test_failure_in_second_patch_aborts_all(agent, file_with_text): + good_old, good_new = "alpha\n", "ALPHA\n" + bad_old = "missing\n" + + response = agent.tool.apply_patch( + path=file_with_text, + patches=[ + {"anchor_hash": anchor_hash(good_old), "old": good_old, "new": good_new}, + {"anchor_hash": anchor_hash(bad_old), "old": bad_old, "new": "X\n"}, + ], + ) + + assert response["status"] == "error" + payload = _payload(response) + assert payload["results"][0]["status"] == "success" + assert payload["results"][1]["status"] == "not_found" + assert _read(file_with_text) == "alpha\nbeta\ngamma\n" + + +def test_empty_patches_is_an_error(agent, file_with_text): + response = agent.tool.apply_patch(path=file_with_text, patches=[]) + assert response["status"] == "error" + assert "non-empty" in _payload(response)["error"] + + +def test_missing_file_returns_error(agent, tmp_path): + response = agent.tool.apply_patch( + path=str(tmp_path / "nope.py"), + patches=[{"anchor_hash": anchor_hash("x"), "old": "x", "new": "y"}], + ) + assert response["status"] == "error" + assert "Could not read" in _payload(response)["error"] diff --git a/tests/test_ast_context.py b/tests/test_ast_context.py new file mode 100644 index 00000000..9766941b --- /dev/null +++ b/tests/test_ast_context.py @@ -0,0 +1,106 @@ +"""Tests for the ast_context tool.""" + +import pytest +from strands import Agent + +from strands_tools import ast_context + + +@pytest.fixture +def agent(): + return Agent(tools=[ast_context]) + + +@pytest.fixture +def py_file(tmp_path): + path = tmp_path / "module.py" + path.write_text( + '"""A short module."""\n' + "import os\n" + "from collections import defaultdict as dd\n" + "\n" + "CONST: int = 1\n" + "other = 2\n" + "\n" + "@staticmethod\n" + "def hello(name, *args, **kwargs):\n" + ' """Say hi."""\n' + " return name\n" + "\n" + "async def fetch(url, /, retries=3):\n" + " return None\n" + "\n" + "class Greeter(Base, metaclass=Meta):\n" + " def __init__(self, x):\n" + " self.x = x\n" + " async def greet(self, name):\n" + " return name\n" + ) + return str(path) + + +def _payload(tool_response): + for block in tool_response["content"]: + if "json" in block: + return block["json"] + raise AssertionError("No JSON content block in tool response") + + +def test_outline_basic(agent, py_file): + response = agent.tool.ast_context(path=py_file) + assert response["status"] == "success" + payload = _payload(response) + + assert payload["docstring"] == "A short module." + + assert {imp["module"] for imp in payload["imports"] if imp["kind"] == "import"} == {"os"} + from_imports = [imp for imp in payload["imports"] if imp["kind"] == "from_import"] + assert from_imports[0]["module"] == "collections" + assert from_imports[0]["name"] == "defaultdict" + assert from_imports[0]["asname"] == "dd" + + assert {a["name"] for a in payload["assignments"]} == {"CONST", "other"} + + funcs = {f["name"]: f for f in payload["functions"]} + assert "hello" in funcs and "fetch" in funcs + assert funcs["hello"]["signature"] == "def hello(name, *args, **kwargs)" + assert funcs["fetch"]["kind"] == "async_function" + assert funcs["fetch"]["signature"] == "async def fetch(url, /, retries=3)" + assert funcs["hello"]["decorators"] == ["staticmethod"] + + classes = {c["name"]: c for c in payload["classes"]} + assert "Greeter" in classes + method_names = {m["name"] for m in classes["Greeter"]["methods"]} + assert method_names == {"__init__", "greet"} + assert "Base" in classes["Greeter"]["bases"] + + +def test_line_ranges_make_sense(agent, py_file): + response = agent.tool.ast_context(path=py_file) + payload = _payload(response) + greeter = next(c for c in payload["classes"] if c["name"] == "Greeter") + start, end = greeter["lines"] + assert start < end + + +def test_optional_sections_can_be_omitted(agent, py_file): + response = agent.tool.ast_context(path=py_file, include_imports=False, include_assignments=False) + payload = _payload(response) + assert "imports" not in payload + assert "assignments" not in payload + assert "classes" in payload and "functions" in payload + + +def test_syntax_error_is_reported(agent, tmp_path): + bad = tmp_path / "bad.py" + bad.write_text("def broken(:\n pass\n") + + response = agent.tool.ast_context(path=str(bad)) + assert response["status"] == "error" + assert "SyntaxError" in _payload(response)["error"] + + +def test_missing_file_is_reported(agent, tmp_path): + response = agent.tool.ast_context(path=str(tmp_path / "does_not_exist.py")) + assert response["status"] == "error" + assert "Could not read" in _payload(response)["error"]