From 99d16524093b6c5d7667c5845b13be6f785868c6 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 2 Apr 2026 04:44:02 +0000 Subject: [PATCH 01/21] Add Memory capability with pluggable storage backends Implements a Memory capability (AbstractCapability subclass) for persistent key-value memory across agent sessions, addressing #30. - MemoryStore protocol with InMemoryStore (dict-based, for testing) and FileStore (JSON file on disk, for persistence) backends - Five tools via get_toolset(): save_memory, recall_memory, search_memories, list_memories, delete_memory - Dynamic instructions via get_instructions() that inject stored memories into the system prompt at run start - Substring-based search across keys, content, and tags - Spec serialization support (Memory.from_spec with backend="memory"|"file") - 48 tests covering all code paths, passing lint, format, and typecheck Co-Authored-By: Claude Opus 4.6 (1M context) --- PLAN.md | 51 ++++ src/pydantic_harness/__init__.py | 10 +- src/pydantic_harness/memory.py | 323 ++++++++++++++++++++++++++ tests/test_memory.py | 387 +++++++++++++++++++++++++++++++ 4 files changed, 770 insertions(+), 1 deletion(-) create mode 100644 PLAN.md create mode 100644 src/pydantic_harness/memory.py create mode 100644 tests/test_memory.py diff --git a/PLAN.md b/PLAN.md new file mode 100644 index 0000000..ca533d8 --- /dev/null +++ b/PLAN.md @@ -0,0 +1,51 @@ +# Memory Capability + +## Summary + +Implements a `Memory` capability (`AbstractCapability` subclass) that provides persistent key-value memory across agent sessions, referencing issues #30 and #31. 
+ +## Design + +### Architecture + +- **`Memory`** dataclass extends `AbstractCapability[AgentDepsT]` + - `get_instructions()` returns a dynamic callable that injects stored memories into the system prompt at run start + - `get_toolset()` returns a `FunctionToolset` with five tools: `save_memory`, `recall_memory`, `search_memories`, `list_memories`, `delete_memory` + - Tool functions use closures over `self.store` (no dependency on agent `deps`) + +### Storage + +- **`MemoryStore`** protocol: pluggable backend with `get`, `put`, `delete`, `list_all`, `search` +- **`InMemoryStore`**: dict-based, ephemeral, for testing (default) +- **`FileStore`**: JSON file on disk, reads on init, writes on every mutation + +### Memory Model + +- **`MemoryEntry`** dataclass: `key`, `content`, `tags` (list[str]), `created_at`, `updated_at` +- Search is substring-based (case-insensitive) across key, content, and tags + +### Spec Serialization + +- `Memory.get_serialization_name()` returns `"Memory"` +- `Memory.from_spec(backend="file", path="...")` creates a `FileStore`-backed instance + +## Configuration + +| Field | Default | Description | +|-------|---------|-------------| +| `store` | `InMemoryStore()` | Storage backend | +| `inject_memories_in_instructions` | `True` | Include memories in system prompt | +| `max_instructions_memories` | `20` | Cap on memories injected into prompt | + +## Files + +- `src/pydantic_harness/memory.py` - Capability, stores, entry model +- `src/pydantic_harness/__init__.py` - Re-exports +- `tests/test_memory.py` - 48 tests covering all code paths + +## Future Work + +- Semantic/vector search backend (e.g. 
embedding-based `MemoryStore`) +- TTL / expiration on entries +- Session-scoped memory isolation via `for_run()` +- SQLite / Redis backends for production persistence diff --git a/src/pydantic_harness/__init__.py b/src/pydantic_harness/__init__.py index 9d728b6..62831af 100644 --- a/src/pydantic_harness/__init__.py +++ b/src/pydantic_harness/__init__.py @@ -7,4 +7,12 @@ # Each capability module is imported and re-exported here. # Capabilities are listed alphabetically. -__all__: list[str] = [] +from pydantic_harness.memory import FileStore, InMemoryStore, Memory, MemoryEntry, MemoryStore + +__all__: list[str] = [ + 'FileStore', + 'InMemoryStore', + 'Memory', + 'MemoryEntry', + 'MemoryStore', +] diff --git a/src/pydantic_harness/memory.py b/src/pydantic_harness/memory.py new file mode 100644 index 0000000..8e8db26 --- /dev/null +++ b/src/pydantic_harness/memory.py @@ -0,0 +1,323 @@ +"""Memory capability for persistent agent memory across sessions. + +Provides tools for saving, recalling, searching, listing, and deleting +key-value memories, with pluggable storage backends (`InMemoryStore` for +testing, `FileStore` for on-disk persistence). 
+""" + +from __future__ import annotations + +import json +from dataclasses import dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Protocol, runtime_checkable + +from pydantic_ai._instructions import AgentInstructions +from pydantic_ai.capabilities.abstract import AbstractCapability +from pydantic_ai.tools import AgentDepsT, RunContext, Tool +from pydantic_ai.toolsets import AgentToolset +from pydantic_ai.toolsets.function import FunctionToolset + + +@dataclass +class MemoryEntry: + """A single memory entry with content, tags, and timestamps.""" + + key: str + """Unique identifier for this memory.""" + + content: str + """The content of the memory.""" + + tags: list[str] = field(default_factory=list[str]) + """Optional tags for categorization and search.""" + + created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) + """ISO 8601 timestamp of when the memory was first created.""" + + updated_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) + """ISO 8601 timestamp of the last update.""" + + def to_dict(self) -> dict[str, Any]: + """Serialize to a plain dict for JSON storage.""" + return { + 'key': self.key, + 'content': self.content, + 'tags': self.tags, + 'created_at': self.created_at, + 'updated_at': self.updated_at, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> MemoryEntry: + """Deserialize from a plain dict.""" + return cls( + key=data['key'], + content=data['content'], + tags=data.get('tags', []), + created_at=data.get('created_at', ''), + updated_at=data.get('updated_at', ''), + ) + + +@runtime_checkable +class MemoryStore(Protocol): + """Protocol for pluggable memory storage backends.""" + + def get(self, key: str) -> MemoryEntry | None: + """Retrieve a memory entry by key, or None if not found.""" + ... + + def put(self, entry: MemoryEntry) -> None: + """Store or update a memory entry.""" + ... 
+ + def delete(self, key: str) -> bool: + """Delete a memory entry by key. Returns True if it existed.""" + ... + + def list_all(self) -> list[MemoryEntry]: + """Return all stored memory entries.""" + ... + + def search(self, query: str) -> list[MemoryEntry]: + """Search entries by substring match on key, content, or tags.""" + ... + + +class InMemoryStore: + """Dict-based in-memory store, suitable for testing. + + All data lives in a plain ``dict`` and is lost when the process exits. + """ + + def __init__(self) -> None: + """Initialize an empty in-memory store.""" + self._entries: dict[str, MemoryEntry] = {} + + def get(self, key: str) -> MemoryEntry | None: + """Retrieve a memory entry by key.""" + return self._entries.get(key) + + def put(self, entry: MemoryEntry) -> None: + """Store or update a memory entry.""" + self._entries[entry.key] = entry + + def delete(self, key: str) -> bool: + """Delete a memory entry by key.""" + return self._entries.pop(key, None) is not None + + def list_all(self) -> list[MemoryEntry]: + """Return all stored memory entries.""" + return list(self._entries.values()) + + def search(self, query: str) -> list[MemoryEntry]: + """Search entries by substring match on key, content, or tags.""" + q = query.lower() + return [ + entry + for entry in self._entries.values() + if q in entry.key.lower() or q in entry.content.lower() or any(q in tag.lower() for tag in entry.tags) + ] + + +class FileStore: + """JSON-file-based store for simple on-disk persistence. + + Reads the file on initialization and writes back on every mutation. 
+ """ + + def __init__(self, path: str | Path) -> None: + """Initialize a file-backed store at the given path.""" + self._path = Path(path) + self._entries: dict[str, MemoryEntry] = {} + self._load() + + def _load(self) -> None: + if self._path.exists(): + raw: dict[str, Any] = json.loads(self._path.read_text(encoding='utf-8')) + self._entries = {key: MemoryEntry.from_dict(val) for key, val in raw.items()} + + def _save(self) -> None: + self._path.parent.mkdir(parents=True, exist_ok=True) + data = {key: entry.to_dict() for key, entry in self._entries.items()} + self._path.write_text(json.dumps(data, indent=2), encoding='utf-8') + + def get(self, key: str) -> MemoryEntry | None: + """Retrieve a memory entry by key.""" + return self._entries.get(key) + + def put(self, entry: MemoryEntry) -> None: + """Store or update a memory entry.""" + self._entries[entry.key] = entry + self._save() + + def delete(self, key: str) -> bool: + """Delete a memory entry by key.""" + existed = self._entries.pop(key, None) is not None + if existed: + self._save() + return existed + + def list_all(self) -> list[MemoryEntry]: + """Return all stored memory entries.""" + return list(self._entries.values()) + + def search(self, query: str) -> list[MemoryEntry]: + """Search entries by substring match on key, content, or tags.""" + q = query.lower() + return [ + entry + for entry in self._entries.values() + if q in entry.key.lower() or q in entry.content.lower() or any(q in tag.lower() for tag in entry.tags) + ] + + +def format_entry(entry: MemoryEntry) -> str: + """Format a memory entry as a human-readable string.""" + line = f'[{entry.key}] {entry.content}' + if entry.tags: + line += f' (tags: {", ".join(entry.tags)})' + return line + + +@dataclass +class Memory(AbstractCapability[AgentDepsT]): + """Capability for persistent memory across agent sessions. + + Provides tools for saving, recalling, searching, listing, and deleting + key-value memories. 
Uses a pluggable `MemoryStore` backend for storage. + + Example: + ```python {test="skip" lint="skip"} + from pydantic_ai import Agent + from pydantic_harness.memory import Memory, InMemoryStore + + agent = Agent('openai:gpt-4o', capabilities=[Memory(store=InMemoryStore())]) + ``` + """ + + store: MemoryStore = field(default_factory=InMemoryStore) + """The storage backend. Defaults to `InMemoryStore` (ephemeral, dict-based).""" + + inject_memories_in_instructions: bool = True + """Whether to inject existing memories into the system prompt at run start.""" + + max_instructions_memories: int = 20 + """Maximum number of memories to include in the system prompt.""" + + @classmethod + def get_serialization_name(cls) -> str | None: + """Return the name used for spec serialization.""" + return 'Memory' + + @classmethod + def from_spec(cls, *args: Any, **kwargs: Any) -> Memory[Any]: + """Create from spec arguments. + + Supports `backend` kwarg: ``"memory"`` (default) or ``"file"`` (requires `path`). + """ + backend = kwargs.pop('backend', 'memory') + if backend == 'file': + path = kwargs.pop('path', '.memories.json') + return cls(store=FileStore(path), **kwargs) + return cls(store=InMemoryStore(), **kwargs) + + def build_instructions(self, ctx: RunContext[AgentDepsT]) -> str: + """Build dynamic instructions that include currently stored memories.""" + parts: list[str] = [ + 'You have access to a persistent memory system. ' + 'Use it to save important information that should be remembered across conversations.', + ] + if self.inject_memories_in_instructions: + entries = self.store.list_all() + if entries: + parts.append('\nCurrently stored memories:') + for entry in entries[: self.max_instructions_memories]: + parts.append(f'- {format_entry(entry)}') + overflow = len(entries) - self.max_instructions_memories + if overflow > 0: + parts.append(f'... 
and {overflow} more (use list_memories or search_memories to see all).') + return '\n'.join(parts) + + def get_instructions(self) -> AgentInstructions[AgentDepsT] | None: + """Return dynamic instructions that include stored memories.""" + return self.build_instructions + + def get_toolset(self) -> AgentToolset[AgentDepsT] | None: + """Return a toolset with memory management tools. + + Tool functions close over ``self`` to access the store without + requiring anything from the agent's ``deps``. + """ + store = self.store + + def save_memory(key: str, content: str, tags: list[str] | None = None) -> str: + """Save or update a memory entry. + + Args: + key: Unique key for this memory. + content: The content to remember. + tags: Optional tags for categorization and search. + """ + now = datetime.now(timezone.utc).isoformat() + existing = store.get(key) + entry = MemoryEntry( + key=key, + content=content, + tags=tags or [], + created_at=existing.created_at if existing else now, + updated_at=now, + ) + store.put(entry) + return f'Memory saved: {key}' + + def recall_memory(key: str) -> str: + """Recall a specific memory by its key. + + Args: + key: The key of the memory to recall. + """ + entry = store.get(key) + if entry is None: + return f'No memory found for key: {key}' + return format_entry(entry) + + def search_memories(query: str) -> str: + """Search memories by substring match on keys, content, or tags. + + Args: + query: The search query string. + """ + results = store.search(query) + if not results: + return f'No memories found matching: {query}' + return '\n'.join(format_entry(entry) for entry in results) + + def list_memories() -> str: + """List all stored memories.""" + entries = store.list_all() + if not entries: + return 'No memories stored.' + return '\n'.join(format_entry(entry) for entry in entries) + + def delete_memory(key: str) -> str: + """Delete a memory by its key. + + Args: + key: The key of the memory to delete. 
+ """ + if store.delete(key): + return f'Memory deleted: {key}' + return f'No memory found for key: {key}' + + return FunctionToolset( + [ + Tool(save_memory, takes_ctx=False), + Tool(recall_memory, takes_ctx=False), + Tool(search_memories, takes_ctx=False), + Tool(list_memories, takes_ctx=False), + Tool(delete_memory, takes_ctx=False), + ], + ) diff --git a/tests/test_memory.py b/tests/test_memory.py new file mode 100644 index 0000000..890bf00 --- /dev/null +++ b/tests/test_memory.py @@ -0,0 +1,387 @@ +"""Tests for the Memory capability.""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +from pydantic_ai._run_context import RunContext +from pydantic_ai.toolsets.function import FunctionToolset +from pydantic_ai.usage import RunUsage + +from pydantic_harness.memory import ( + FileStore, + InMemoryStore, + Memory, + MemoryEntry, + MemoryStore, + format_entry, +) + +# --- MemoryEntry --- + + +class TestMemoryEntry: + def test_round_trip(self) -> None: + entry = MemoryEntry(key='k', content='v', tags=['a', 'b'], created_at='t1', updated_at='t2') + assert MemoryEntry.from_dict(entry.to_dict()) == entry + + def test_from_dict_defaults(self) -> None: + entry = MemoryEntry.from_dict({'key': 'k', 'content': 'v'}) + assert entry.tags == [] + assert entry.created_at == '' + assert entry.updated_at == '' + + def test_default_timestamps(self) -> None: + entry = MemoryEntry(key='k', content='v') + assert entry.created_at # non-empty ISO string + assert entry.updated_at + + +# --- InMemoryStore --- + + +class TestInMemoryStore: + def test_put_and_get(self) -> None: + store = InMemoryStore() + entry = MemoryEntry(key='greeting', content='hello') + store.put(entry) + assert store.get('greeting') is entry + + def test_get_missing(self) -> None: + store = InMemoryStore() + assert store.get('nope') is None + + def test_put_overwrites(self) -> None: + store = InMemoryStore() + store.put(MemoryEntry(key='k', content='v1')) + 
store.put(MemoryEntry(key='k', content='v2')) + result = store.get('k') + assert result is not None + assert result.content == 'v2' + + def test_delete_existing(self) -> None: + store = InMemoryStore() + store.put(MemoryEntry(key='k', content='v')) + assert store.delete('k') is True + assert store.get('k') is None + + def test_delete_missing(self) -> None: + store = InMemoryStore() + assert store.delete('nope') is False + + def test_list_all_empty(self) -> None: + store = InMemoryStore() + assert store.list_all() == [] + + def test_list_all(self) -> None: + store = InMemoryStore() + store.put(MemoryEntry(key='a', content='alpha')) + store.put(MemoryEntry(key='b', content='beta')) + entries = store.list_all() + assert len(entries) == 2 + assert {e.key for e in entries} == {'a', 'b'} + + def test_search_by_key(self) -> None: + store = InMemoryStore() + store.put(MemoryEntry(key='user_name', content='Alice')) + store.put(MemoryEntry(key='color', content='blue')) + results = store.search('user') + assert len(results) == 1 + assert results[0].key == 'user_name' + + def test_search_by_content(self) -> None: + store = InMemoryStore() + store.put(MemoryEntry(key='k1', content='the quick brown fox')) + store.put(MemoryEntry(key='k2', content='lazy dog')) + results = store.search('fox') + assert len(results) == 1 + assert results[0].key == 'k1' + + def test_search_by_tag(self) -> None: + store = InMemoryStore() + store.put(MemoryEntry(key='k1', content='x', tags=['important'])) + store.put(MemoryEntry(key='k2', content='y', tags=['trivial'])) + results = store.search('important') + assert len(results) == 1 + assert results[0].key == 'k1' + + def test_search_case_insensitive(self) -> None: + store = InMemoryStore() + store.put(MemoryEntry(key='K1', content='Hello World')) + results = store.search('hello') + assert len(results) == 1 + + def test_search_no_results(self) -> None: + store = InMemoryStore() + store.put(MemoryEntry(key='k', content='v')) + assert 
store.search('zzz') == [] + + +# --- FileStore --- + + +class TestFileStore: + def test_put_and_get(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileStore(path) + store.put(MemoryEntry(key='k', content='v')) + assert store.get('k') is not None + assert store.get('k').content == 'v' # type: ignore[union-attr] + + def test_persistence(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store1 = FileStore(path) + store1.put(MemoryEntry(key='k', content='persisted')) + + # New store instance should load from disk + store2 = FileStore(path) + result = store2.get('k') + assert result is not None + assert result.content == 'persisted' + + def test_delete_saves(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileStore(path) + store.put(MemoryEntry(key='k', content='v')) + store.delete('k') + + # Reload and verify deletion persisted + store2 = FileStore(path) + assert store2.get('k') is None + + def test_list_all(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileStore(path) + store.put(MemoryEntry(key='a', content='alpha')) + store.put(MemoryEntry(key='b', content='beta')) + assert len(store.list_all()) == 2 + + def test_search(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileStore(path) + store.put(MemoryEntry(key='k1', content='hello', tags=['greeting'])) + store.put(MemoryEntry(key='k2', content='world')) + assert len(store.search('greeting')) == 1 + assert len(store.search('hello')) == 1 + + def test_empty_file(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + # File does not exist yet + store = FileStore(path) + assert store.list_all() == [] + + def test_creates_parent_dirs(self, tmp_path: Path) -> None: + path = tmp_path / 'sub' / 'dir' / 'mem.json' + store = FileStore(path) + store.put(MemoryEntry(key='k', content='v')) + assert path.exists() + + def test_file_format(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = 
FileStore(path) + store.put(MemoryEntry(key='k', content='v', tags=['t'], created_at='c', updated_at='u')) + raw = json.loads(path.read_text()) + assert raw == { + 'k': { + 'key': 'k', + 'content': 'v', + 'tags': ['t'], + 'created_at': 'c', + 'updated_at': 'u', + } + } + + +# --- format_entry --- + + +class TestFormatEntry: + def test_no_tags(self) -> None: + entry = MemoryEntry(key='k', content='hello') + assert format_entry(entry) == '[k] hello' + + def test_with_tags(self) -> None: + entry = MemoryEntry(key='k', content='hello', tags=['a', 'b']) + assert format_entry(entry) == '[k] hello (tags: a, b)' + + +# --- Memory capability --- + + +class TestMemoryCapability: + def test_serialization_name(self) -> None: + assert Memory.get_serialization_name() == 'Memory' + + def test_from_spec_default(self) -> None: + cap = Memory.from_spec() + assert isinstance(cap.store, InMemoryStore) + + def test_from_spec_file(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + cap = Memory.from_spec(backend='file', path=str(path)) + assert isinstance(cap.store, FileStore) + + def test_default_store(self) -> None: + cap: Memory[None] = Memory() + assert isinstance(cap.store, InMemoryStore) + + def test_get_toolset_returns_function_toolset(self) -> None: + cap: Memory[None] = Memory() + toolset = cap.get_toolset() + assert isinstance(toolset, FunctionToolset) + + def test_toolset_has_expected_tools(self) -> None: + cap: Memory[None] = Memory() + toolset = cap.get_toolset() + assert isinstance(toolset, FunctionToolset) + tool_names = set(toolset.tools.keys()) + assert tool_names == {'save_memory', 'recall_memory', 'search_memories', 'list_memories', 'delete_memory'} + + +# --- Tool functions (via closure) --- + + +class TestMemoryTools: + """Test the tool functions exposed by the Memory capability.""" + + @staticmethod + def _get_tools(store: InMemoryStore | None = None) -> dict[str, Any]: + cap: Memory[None] = Memory(store=store or InMemoryStore()) + toolset = 
cap.get_toolset() + assert isinstance(toolset, FunctionToolset) + return {name: tool.function for name, tool in toolset.tools.items()} + + def test_save_and_recall(self) -> None: + store = InMemoryStore() + tools = self._get_tools(store) + result = tools['save_memory']('greeting', 'hello world') + assert result == 'Memory saved: greeting' + + recalled = tools['recall_memory']('greeting') + assert '[greeting] hello world' in recalled + + def test_recall_missing(self) -> None: + tools = self._get_tools() + assert 'No memory found' in tools['recall_memory']('nope') + + def test_save_updates_existing(self) -> None: + store = InMemoryStore() + tools = self._get_tools(store) + tools['save_memory']('k', 'v1') + original = store.get('k') + assert original is not None + original_created = original.created_at + + tools['save_memory']('k', 'v2') + updated = store.get('k') + assert updated is not None + assert updated.content == 'v2' + # created_at should be preserved + assert updated.created_at == original_created + + def test_save_with_tags(self) -> None: + store = InMemoryStore() + tools = self._get_tools(store) + tools['save_memory']('k', 'v', ['tag1', 'tag2']) + entry = store.get('k') + assert entry is not None + assert entry.tags == ['tag1', 'tag2'] + + def test_search(self) -> None: + store = InMemoryStore() + tools = self._get_tools(store) + tools['save_memory']('user_name', 'Alice') + tools['save_memory']('color', 'blue') + + result = tools['search_memories']('alice') + assert 'Alice' in result + assert 'blue' not in result + + def test_search_no_results(self) -> None: + tools = self._get_tools() + assert 'No memories found' in tools['search_memories']('zzz') + + def test_list_empty(self) -> None: + tools = self._get_tools() + assert tools['list_memories']() == 'No memories stored.' 
+ + def test_list_with_entries(self) -> None: + store = InMemoryStore() + tools = self._get_tools(store) + tools['save_memory']('a', 'alpha') + tools['save_memory']('b', 'beta') + result = tools['list_memories']() + assert '[a] alpha' in result + assert '[b] beta' in result + + def test_delete_existing(self) -> None: + store = InMemoryStore() + tools = self._get_tools(store) + tools['save_memory']('k', 'v') + assert tools['delete_memory']('k') == 'Memory deleted: k' + assert store.get('k') is None + + def test_delete_missing(self) -> None: + tools = self._get_tools() + assert 'No memory found' in tools['delete_memory']('nope') + + +# --- Instructions --- + + +class TestMemoryInstructions: + @staticmethod + def _make_ctx() -> RunContext[None]: + from unittest.mock import MagicMock + + return RunContext( + deps=None, + model=MagicMock(), + usage=RunUsage(), + ) + + def test_get_instructions_is_callable(self) -> None: + cap: Memory[None] = Memory() + assert callable(cap.get_instructions()) + + def test_instructions_with_no_memories(self) -> None: + cap: Memory[None] = Memory() + text = cap.build_instructions(self._make_ctx()) + assert 'persistent memory system' in text + assert 'Currently stored memories' not in text + + def test_instructions_with_memories(self) -> None: + store = InMemoryStore() + store.put(MemoryEntry(key='user', content='Alice')) + cap: Memory[None] = Memory(store=store) + text = cap.build_instructions(self._make_ctx()) + assert 'Currently stored memories' in text + assert '[user] Alice' in text + + def test_instructions_respects_max(self) -> None: + store = InMemoryStore() + for i in range(25): + store.put(MemoryEntry(key=f'k{i}', content=f'v{i}')) + cap: Memory[None] = Memory(store=store, max_instructions_memories=5) + text = cap.build_instructions(self._make_ctx()) + assert '... 
and 20 more' in text + + def test_instructions_disabled(self) -> None: + store = InMemoryStore() + store.put(MemoryEntry(key='k', content='v')) + cap: Memory[None] = Memory(store=store, inject_memories_in_instructions=False) + text = cap.build_instructions(self._make_ctx()) + assert 'Currently stored memories' not in text + + +# --- MemoryStore protocol --- + + +class TestMemoryStoreProtocol: + def test_in_memory_store_satisfies_protocol(self) -> None: + assert isinstance(InMemoryStore(), MemoryStore) + + def test_file_store_satisfies_protocol(self, tmp_path: Path) -> None: + assert isinstance(FileStore(tmp_path / 'mem.json'), MemoryStore) From 6feffca8435b70896d7d8877c3dd160787d38479 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 2 Apr 2026 05:50:03 +0000 Subject: [PATCH 02/21] Improve Memory capability: better search, scoping, TTL, dedup warning Address audit findings from PR review: - Better search: word-boundary matching with relevance scoring (count of matching words across key/content/tags, sorted by score descending). Underscores and hyphens treated as word separators. - Memory scoping: `scope: str = 'global'` field on MemoryEntry, with optional `scope` parameter on `search_memories` and `list_memories` tools and `list_all`/`search` store methods. - TTL/expiration: `expires_at: str | None = None` on MemoryEntry with `is_expired()` method. Stores filter out expired entries automatically. `save_memory` tool accepts optional `ttl_minutes` parameter. - Dedup warning: when saving a memory whose key is very similar to an existing key (same 10-char prefix, Levenshtein distance <= 2), log a warning via the `pydantic_harness.memory` logger. Tests: 48 -> 99, all passing with 100% coverage. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- src/pydantic_harness/memory.py | 206 +++++++++++++++++---- tests/test_memory.py | 314 ++++++++++++++++++++++++++++++++- 2 files changed, 482 insertions(+), 38 deletions(-) diff --git a/src/pydantic_harness/memory.py b/src/pydantic_harness/memory.py index 8e8db26..5649265 100644 --- a/src/pydantic_harness/memory.py +++ b/src/pydantic_harness/memory.py @@ -8,8 +8,10 @@ from __future__ import annotations import json +import logging +import re from dataclasses import dataclass, field -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from pathlib import Path from typing import Any, Protocol, runtime_checkable @@ -19,6 +21,8 @@ from pydantic_ai.toolsets import AgentToolset from pydantic_ai.toolsets.function import FunctionToolset +logger = logging.getLogger(__name__) + @dataclass class MemoryEntry: @@ -33,18 +37,32 @@ class MemoryEntry: tags: list[str] = field(default_factory=list[str]) """Optional tags for categorization and search.""" + scope: str = 'global' + """Namespace scope for this memory (default ``'global'``).""" + + expires_at: str | None = None + """Optional ISO 8601 expiration timestamp. 
``None`` means no expiry.""" + created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) """ISO 8601 timestamp of when the memory was first created.""" updated_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) """ISO 8601 timestamp of the last update.""" + def is_expired(self) -> bool: + """Return True if this entry has passed its expiration time.""" + if self.expires_at is None: + return False + return datetime.fromisoformat(self.expires_at) <= datetime.now(timezone.utc) + def to_dict(self) -> dict[str, Any]: """Serialize to a plain dict for JSON storage.""" return { 'key': self.key, 'content': self.content, 'tags': self.tags, + 'scope': self.scope, + 'expires_at': self.expires_at, 'created_at': self.created_at, 'updated_at': self.updated_at, } @@ -56,33 +74,86 @@ def from_dict(cls, data: dict[str, Any]) -> MemoryEntry: key=data['key'], content=data['content'], tags=data.get('tags', []), + scope=data.get('scope', 'global'), + expires_at=data.get('expires_at'), created_at=data.get('created_at', ''), updated_at=data.get('updated_at', ''), ) +def _score_entry(entry: MemoryEntry, words: list[str]) -> int: + r"""Score a memory entry by counting word-boundary matches across fields. + + Each query word that appears as a whole word (case-insensitive) in the + key, content, or any tag contributes one point per field it appears in. + Underscores and hyphens are treated as word separators in addition to + the standard ``\b`` boundaries. + """ + score = 0 + for word in words: + # Use a boundary pattern that also treats _ and - as separators. + escaped = re.escape(word) + pattern = re.compile(rf'(? bool: + """Return True if two keys share the same first 10 characters and differ only slightly. + + Uses a simple character-level edit distance check: keys are considered + similar when they share the same 10-char prefix and differ by at most 2 + characters (Levenshtein-like). 
+ """ + if len(a) < 10 or len(b) < 10: + return False + if a[:10] != b[:10]: + return False + if a == b: + return False + # Simple Levenshtein-like check: allow at most 2 edits + if abs(len(a) - len(b)) > 2: + return False + # Bounded character-level distance (sufficient for dedup warnings) + max_edits = 2 + m, n = len(a), len(b) + prev = list(range(n + 1)) + for i in range(1, m + 1): + curr = [i] + [0] * n + for j in range(1, n + 1): + cost = 0 if a[i - 1] == b[j - 1] else 1 + curr[j] = min(curr[j - 1] + 1, prev[j] + 1, prev[j - 1] + cost) + prev = curr + return prev[n] <= max_edits + + @runtime_checkable class MemoryStore(Protocol): """Protocol for pluggable memory storage backends.""" - def get(self, key: str) -> MemoryEntry | None: + def get(self, key: str) -> MemoryEntry | None: # pragma: no cover """Retrieve a memory entry by key, or None if not found.""" ... - def put(self, entry: MemoryEntry) -> None: + def put(self, entry: MemoryEntry) -> None: # pragma: no cover """Store or update a memory entry.""" ... - def delete(self, key: str) -> bool: + def delete(self, key: str) -> bool: # pragma: no cover """Delete a memory entry by key. Returns True if it existed.""" ... - def list_all(self) -> list[MemoryEntry]: - """Return all stored memory entries.""" + def list_all(self, *, scope: str | None = None) -> list[MemoryEntry]: # pragma: no cover + """Return all non-expired entries, optionally filtered by scope.""" ... - def search(self, query: str) -> list[MemoryEntry]: - """Search entries by substring match on key, content, or tags.""" + def search(self, query: str, *, scope: str | None = None) -> list[MemoryEntry]: # pragma: no cover + """Search non-expired entries with word-boundary matching, sorted by relevance.""" ... 
@@ -108,19 +179,31 @@ def delete(self, key: str) -> bool: """Delete a memory entry by key.""" return self._entries.pop(key, None) is not None - def list_all(self) -> list[MemoryEntry]: - """Return all stored memory entries.""" - return list(self._entries.values()) - - def search(self, query: str) -> list[MemoryEntry]: - """Search entries by substring match on key, content, or tags.""" - q = query.lower() + def list_all(self, *, scope: str | None = None) -> list[MemoryEntry]: + """Return all non-expired entries, optionally filtered by scope.""" return [ entry for entry in self._entries.values() - if q in entry.key.lower() or q in entry.content.lower() or any(q in tag.lower() for tag in entry.tags) + if not entry.is_expired() and (scope is None or entry.scope == scope) ] + def search(self, query: str, *, scope: str | None = None) -> list[MemoryEntry]: + """Search non-expired entries with word-boundary matching, sorted by relevance.""" + words = query.lower().split() + if not words: + return [] + scored: list[tuple[int, MemoryEntry]] = [] + for entry in self._entries.values(): + if entry.is_expired(): + continue + if scope is not None and entry.scope != scope: + continue + score = _score_entry(entry, words) + if score > 0: + scored.append((score, entry)) + scored.sort(key=lambda pair: pair[0], reverse=True) + return [entry for _, entry in scored] + class FileStore: """JSON-file-based store for simple on-disk persistence. 
@@ -160,25 +243,44 @@ def delete(self, key: str) -> bool: self._save() return existed - def list_all(self) -> list[MemoryEntry]: - """Return all stored memory entries.""" - return list(self._entries.values()) - - def search(self, query: str) -> list[MemoryEntry]: - """Search entries by substring match on key, content, or tags.""" - q = query.lower() + def list_all(self, *, scope: str | None = None) -> list[MemoryEntry]: + """Return all non-expired entries, optionally filtered by scope.""" return [ entry for entry in self._entries.values() - if q in entry.key.lower() or q in entry.content.lower() or any(q in tag.lower() for tag in entry.tags) + if not entry.is_expired() and (scope is None or entry.scope == scope) ] + def search(self, query: str, *, scope: str | None = None) -> list[MemoryEntry]: + """Search non-expired entries with word-boundary matching, sorted by relevance.""" + words = query.lower().split() + if not words: + return [] + scored: list[tuple[int, MemoryEntry]] = [] + for entry in self._entries.values(): + if entry.is_expired(): + continue + if scope is not None and entry.scope != scope: + continue + score = _score_entry(entry, words) + if score > 0: + scored.append((score, entry)) + scored.sort(key=lambda pair: pair[0], reverse=True) + return [entry for _, entry in scored] + def format_entry(entry: MemoryEntry) -> str: """Format a memory entry as a human-readable string.""" line = f'[{entry.key}] {entry.content}' + extras: list[str] = [] if entry.tags: - line += f' (tags: {", ".join(entry.tags)})' + extras.append(f'tags: {", ".join(entry.tags)}') + if entry.scope != 'global': + extras.append(f'scope: {entry.scope}') + if entry.expires_at is not None: + extras.append(f'expires: {entry.expires_at}') + if extras: + line += f' ({"; ".join(extras)})' return line @@ -253,22 +355,47 @@ def get_toolset(self) -> AgentToolset[AgentDepsT] | None: """ store = self.store - def save_memory(key: str, content: str, tags: list[str] | None = None) -> str: + def 
save_memory( + key: str, + content: str, + tags: list[str] | None = None, + scope: str = 'global', + ttl_minutes: int | None = None, + ) -> str: """Save or update a memory entry. Args: key: Unique key for this memory. content: The content to remember. tags: Optional tags for categorization and search. + scope: Namespace scope (default ``'global'``). + ttl_minutes: Optional time-to-live in minutes. The entry will expire after this duration. """ - now = datetime.now(timezone.utc).isoformat() + now = datetime.now(timezone.utc) + now_iso = now.isoformat() existing = store.get(key) + + # Dedup warning: check for similar keys among existing entries + for existing_entry in store.list_all(): + if _simple_similarity(key, existing_entry.key): + logger.warning( + 'New memory key %r is very similar to existing key %r — possible duplicate', + key, + existing_entry.key, + ) + + expires_at: str | None = None + if ttl_minutes is not None: + expires_at = (now + timedelta(minutes=ttl_minutes)).isoformat() + entry = MemoryEntry( key=key, content=content, tags=tags or [], - created_at=existing.created_at if existing else now, - updated_at=now, + scope=scope, + expires_at=expires_at, + created_at=existing.created_at if existing else now_iso, + updated_at=now_iso, ) store.put(entry) return f'Memory saved: {key}' @@ -282,22 +409,29 @@ def recall_memory(key: str) -> str: entry = store.get(key) if entry is None: return f'No memory found for key: {key}' + if entry.is_expired(): + return f'No memory found for key: {key}' return format_entry(entry) - def search_memories(query: str) -> str: - """Search memories by substring match on keys, content, or tags. + def search_memories(query: str, scope: str | None = None) -> str: + """Search memories by word-boundary matching on keys, content, or tags, sorted by relevance. Args: - query: The search query string. + query: The search query string (space-separated words). + scope: Optional scope to restrict the search to. 
""" - results = store.search(query) + results = store.search(query, scope=scope) if not results: return f'No memories found matching: {query}' return '\n'.join(format_entry(entry) for entry in results) - def list_memories() -> str: - """List all stored memories.""" - entries = store.list_all() + def list_memories(scope: str | None = None) -> str: + """List all stored memories, optionally filtered by scope. + + Args: + scope: Optional scope to filter by. + """ + entries = store.list_all(scope=scope) if not entries: return 'No memories stored.' return '\n'.join(format_entry(entry) for entry in entries) diff --git a/tests/test_memory.py b/tests/test_memory.py index 890bf00..6aa9f1e 100644 --- a/tests/test_memory.py +++ b/tests/test_memory.py @@ -3,6 +3,8 @@ from __future__ import annotations import json +import logging +from datetime import datetime, timedelta, timezone from pathlib import Path from typing import Any @@ -16,6 +18,8 @@ Memory, MemoryEntry, MemoryStore, + _score_entry, + _simple_similarity, format_entry, ) @@ -24,12 +28,22 @@ class TestMemoryEntry: def test_round_trip(self) -> None: - entry = MemoryEntry(key='k', content='v', tags=['a', 'b'], created_at='t1', updated_at='t2') + entry = MemoryEntry( + key='k', + content='v', + tags=['a', 'b'], + scope='project', + expires_at='2099-01-01T00:00:00+00:00', + created_at='t1', + updated_at='t2', + ) assert MemoryEntry.from_dict(entry.to_dict()) == entry def test_from_dict_defaults(self) -> None: entry = MemoryEntry.from_dict({'key': 'k', 'content': 'v'}) assert entry.tags == [] + assert entry.scope == 'global' + assert entry.expires_at is None assert entry.created_at == '' assert entry.updated_at == '' @@ -38,6 +52,88 @@ def test_default_timestamps(self) -> None: assert entry.created_at # non-empty ISO string assert entry.updated_at + def test_default_scope(self) -> None: + entry = MemoryEntry(key='k', content='v') + assert entry.scope == 'global' + + def test_is_expired_no_expiry(self) -> None: + entry = 
MemoryEntry(key='k', content='v') + assert not entry.is_expired() + + def test_is_expired_future(self) -> None: + future = (datetime.now(timezone.utc) + timedelta(hours=1)).isoformat() + entry = MemoryEntry(key='k', content='v', expires_at=future) + assert not entry.is_expired() + + def test_is_expired_past(self) -> None: + past = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() + entry = MemoryEntry(key='k', content='v', expires_at=past) + assert entry.is_expired() + + +# --- _score_entry --- + + +class TestScoreEntry: + def test_no_match(self) -> None: + entry = MemoryEntry(key='greeting', content='hello world') + assert _score_entry(entry, ['zzz']) == 0 + + def test_key_match(self) -> None: + entry = MemoryEntry(key='greeting', content='some text') + assert _score_entry(entry, ['greeting']) == 1 + + def test_content_match(self) -> None: + entry = MemoryEntry(key='k', content='hello world') + assert _score_entry(entry, ['hello']) == 1 + + def test_tag_match(self) -> None: + entry = MemoryEntry(key='k', content='text', tags=['important']) + assert _score_entry(entry, ['important']) == 1 + + def test_multiple_field_match(self) -> None: + entry = MemoryEntry(key='hello', content='hello world', tags=['hello']) + # 'hello' appears in key (1) + content (1) + tags (1) = 3 + assert _score_entry(entry, ['hello']) == 3 + + def test_multiple_words(self) -> None: + entry = MemoryEntry(key='user', content='Alice likes blue') + # 'alice' in content (1), 'blue' in content (1) = 2 + assert _score_entry(entry, ['alice', 'blue']) == 2 + + def test_word_boundary_no_partial(self) -> None: + # 'fox' should NOT match 'foxes' with word-boundary matching + entry = MemoryEntry(key='k', content='foxes jump') + assert _score_entry(entry, ['fox']) == 0 + + +# --- _simple_similarity --- + + +class TestSimpleSimilarity: + def test_identical_keys_not_similar(self) -> None: + assert not _simple_similarity('abcdefghij', 'abcdefghij') + + def test_short_keys_not_similar(self) -> 
None: + assert not _simple_similarity('abc', 'abd') + + def test_similar_long_keys(self) -> None: + # Differ by 2 characters ('fo' vs 'ba') — within the edit-distance threshold + assert _simple_similarity('abcdefghij_fo', 'abcdefghij_ba') + + def test_different_prefix(self) -> None: + assert not _simple_similarity('xxxxxxxxxxfoo', 'yyyyyyyyyyfoo') + + def test_same_prefix_large_edit(self) -> None: + assert not _simple_similarity('abcdefghijklmnop', 'abcdefghijzzzzzz') + + def test_length_diff_too_large(self) -> None: + # Same 10-char prefix but length differs by more than 2 + assert not _simple_similarity('abcdefghij_x', 'abcdefghij_xyzw') + + def test_one_char_diff(self) -> None: + assert _simple_similarity('abcdefghij_x', 'abcdefghij_y') + # --- InMemoryStore --- @@ -83,6 +179,29 @@ def test_list_all(self) -> None: assert len(entries) == 2 assert {e.key for e in entries} == {'a', 'b'} + def test_list_all_filters_expired(self) -> None: + store = InMemoryStore() + past = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() + store.put(MemoryEntry(key='alive', content='fresh')) + store.put(MemoryEntry(key='dead', content='stale', expires_at=past)) + entries = store.list_all() + assert len(entries) == 1 + assert entries[0].key == 'alive' + + def test_list_all_scope_filter(self) -> None: + store = InMemoryStore() + store.put(MemoryEntry(key='a', content='x', scope='project')) + store.put(MemoryEntry(key='b', content='y', scope='global')) + entries = store.list_all(scope='project') + assert len(entries) == 1 + assert entries[0].key == 'a' + + def test_list_all_scope_none_returns_all(self) -> None: + store = InMemoryStore() + store.put(MemoryEntry(key='a', content='x', scope='project')) + store.put(MemoryEntry(key='b', content='y', scope='global')) + assert len(store.list_all(scope=None)) == 2 + def test_search_by_key(self) -> None: store = InMemoryStore() store.put(MemoryEntry(key='user_name', content='Alice')) @@ -118,6 +237,39 @@ def 
test_search_no_results(self) -> None: store.put(MemoryEntry(key='k', content='v')) assert store.search('zzz') == [] + def test_search_filters_expired(self) -> None: + store = InMemoryStore() + past = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() + store.put(MemoryEntry(key='alive', content='hello world')) + store.put(MemoryEntry(key='dead', content='hello world', expires_at=past)) + results = store.search('hello') + assert len(results) == 1 + assert results[0].key == 'alive' + + def test_search_scope_filter(self) -> None: + store = InMemoryStore() + store.put(MemoryEntry(key='a', content='hello world', scope='project')) + store.put(MemoryEntry(key='b', content='hello world', scope='global')) + results = store.search('hello', scope='project') + assert len(results) == 1 + assert results[0].key == 'a' + + def test_search_relevance_ordering(self) -> None: + store = InMemoryStore() + # 'hello' appears in key + content = score 2 + store.put(MemoryEntry(key='hello', content='hello there')) + # 'hello' appears only in content = score 1 + store.put(MemoryEntry(key='other', content='hello world')) + results = store.search('hello') + assert len(results) == 2 + assert results[0].key == 'hello' # higher score first + assert results[1].key == 'other' + + def test_search_empty_query(self) -> None: + store = InMemoryStore() + store.put(MemoryEntry(key='k', content='v')) + assert store.search('') == [] + # --- FileStore --- @@ -151,6 +303,11 @@ def test_delete_saves(self, tmp_path: Path) -> None: store2 = FileStore(path) assert store2.get('k') is None + def test_delete_missing(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileStore(path) + assert store.delete('nope') is False + def test_list_all(self, tmp_path: Path) -> None: path = tmp_path / 'mem.json' store = FileStore(path) @@ -188,11 +345,68 @@ def test_file_format(self, tmp_path: Path) -> None: 'key': 'k', 'content': 'v', 'tags': ['t'], + 'scope': 'global', + 'expires_at': None, 
'created_at': 'c', 'updated_at': 'u', } } + def test_list_all_filters_expired(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileStore(path) + past = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() + store.put(MemoryEntry(key='alive', content='x')) + store.put(MemoryEntry(key='dead', content='y', expires_at=past)) + assert len(store.list_all()) == 1 + + def test_search_filters_expired(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileStore(path) + past = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() + store.put(MemoryEntry(key='alive', content='hello world')) + store.put(MemoryEntry(key='dead', content='hello world', expires_at=past)) + assert len(store.search('hello')) == 1 + + def test_list_all_scope(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileStore(path) + store.put(MemoryEntry(key='a', content='x', scope='project')) + store.put(MemoryEntry(key='b', content='y', scope='global')) + assert len(store.list_all(scope='project')) == 1 + + def test_search_scope(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileStore(path) + store.put(MemoryEntry(key='a', content='hello world', scope='project')) + store.put(MemoryEntry(key='b', content='hello world', scope='global')) + assert len(store.search('hello', scope='project')) == 1 + + def test_search_empty_query(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileStore(path) + store.put(MemoryEntry(key='k', content='v')) + assert store.search('') == [] + + def test_scope_persists(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store1 = FileStore(path) + store1.put(MemoryEntry(key='k', content='v', scope='session')) + store2 = FileStore(path) + entry = store2.get('k') + assert entry is not None + assert entry.scope == 'session' + + def test_expires_at_persists(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + future = 
(datetime.now(timezone.utc) + timedelta(hours=1)).isoformat() + store1 = FileStore(path) + store1.put(MemoryEntry(key='k', content='v', expires_at=future)) + store2 = FileStore(path) + entry = store2.get('k') + assert entry is not None + assert entry.expires_at == future + # --- format_entry --- @@ -206,6 +420,28 @@ def test_with_tags(self) -> None: entry = MemoryEntry(key='k', content='hello', tags=['a', 'b']) assert format_entry(entry) == '[k] hello (tags: a, b)' + def test_with_scope(self) -> None: + entry = MemoryEntry(key='k', content='hello', scope='project') + assert format_entry(entry) == '[k] hello (scope: project)' + + def test_global_scope_omitted(self) -> None: + entry = MemoryEntry(key='k', content='hello', scope='global') + assert format_entry(entry) == '[k] hello' + + def test_with_expires_at(self) -> None: + entry = MemoryEntry(key='k', content='hello', expires_at='2099-01-01T00:00:00+00:00') + assert format_entry(entry) == '[k] hello (expires: 2099-01-01T00:00:00+00:00)' + + def test_all_extras(self) -> None: + entry = MemoryEntry( + key='k', + content='hello', + tags=['t'], + scope='project', + expires_at='2099-01-01T00:00:00+00:00', + ) + assert format_entry(entry) == '[k] hello (tags: t; scope: project; expires: 2099-01-01T00:00:00+00:00)' + # --- Memory capability --- @@ -266,6 +502,13 @@ def test_recall_missing(self) -> None: tools = self._get_tools() assert 'No memory found' in tools['recall_memory']('nope') + def test_recall_expired(self) -> None: + store = InMemoryStore() + past = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() + store.put(MemoryEntry(key='old', content='stale', expires_at=past)) + tools = self._get_tools(store) + assert 'No memory found' in tools['recall_memory']('old') + def test_save_updates_existing(self) -> None: store = InMemoryStore() tools = self._get_tools(store) @@ -289,13 +532,33 @@ def test_save_with_tags(self) -> None: assert entry is not None assert entry.tags == ['tag1', 'tag2'] + def 
test_save_with_scope(self) -> None: + store = InMemoryStore() + tools = self._get_tools(store) + tools['save_memory']('k', 'v', None, 'project') + entry = store.get('k') + assert entry is not None + assert entry.scope == 'project' + + def test_save_with_ttl(self) -> None: + store = InMemoryStore() + tools = self._get_tools(store) + tools['save_memory']('k', 'v', None, 'global', 60) + entry = store.get('k') + assert entry is not None + assert entry.expires_at is not None + expires = datetime.fromisoformat(entry.expires_at) + # Should expire roughly 60 minutes from now + assert expires > datetime.now(timezone.utc) + timedelta(minutes=59) + assert expires < datetime.now(timezone.utc) + timedelta(minutes=61) + def test_search(self) -> None: store = InMemoryStore() tools = self._get_tools(store) tools['save_memory']('user_name', 'Alice') tools['save_memory']('color', 'blue') - result = tools['search_memories']('alice') + result = tools['search_memories']('Alice') assert 'Alice' in result assert 'blue' not in result @@ -303,6 +566,15 @@ def test_search_no_results(self) -> None: tools = self._get_tools() assert 'No memories found' in tools['search_memories']('zzz') + def test_search_with_scope(self) -> None: + store = InMemoryStore() + tools = self._get_tools(store) + tools['save_memory']('a', 'hello world', None, 'project') + tools['save_memory']('b', 'hello world', None, 'global') + result = tools['search_memories']('hello', 'project') + assert '[a]' in result + assert '[b]' not in result + def test_list_empty(self) -> None: tools = self._get_tools() assert tools['list_memories']() == 'No memories stored.' 
@@ -316,6 +588,15 @@ def test_list_with_entries(self) -> None: assert '[a] alpha' in result assert '[b] beta' in result + def test_list_with_scope(self) -> None: + store = InMemoryStore() + tools = self._get_tools(store) + tools['save_memory']('a', 'alpha', None, 'project') + tools['save_memory']('b', 'beta', None, 'global') + result = tools['list_memories']('project') + assert '[a] alpha' in result + assert '[b]' not in result + def test_delete_existing(self) -> None: store = InMemoryStore() tools = self._get_tools(store) @@ -328,6 +609,35 @@ def test_delete_missing(self) -> None: assert 'No memory found' in tools['delete_memory']('nope') +# --- Dedup warning --- + + +class TestDedupWarning: + def test_similar_key_logs_warning(self, caplog: Any) -> None: + store = InMemoryStore() + tools = TestMemoryTools._get_tools(store) + tools['save_memory']('abcdefghij_x', 'first value') + with caplog.at_level(logging.WARNING, logger='pydantic_harness.memory'): + tools['save_memory']('abcdefghij_y', 'second value') + assert any('possible duplicate' in record.message.lower() for record in caplog.records) + + def test_different_keys_no_warning(self, caplog: Any) -> None: + store = InMemoryStore() + tools = TestMemoryTools._get_tools(store) + tools['save_memory']('first_key_long', 'first value') + with caplog.at_level(logging.WARNING, logger='pydantic_harness.memory'): + tools['save_memory']('other_key_long', 'second value') + assert not any('possible duplicate' in record.message.lower() for record in caplog.records) + + def test_short_keys_no_warning(self, caplog: Any) -> None: + store = InMemoryStore() + tools = TestMemoryTools._get_tools(store) + tools['save_memory']('abc', 'first value') + with caplog.at_level(logging.WARNING, logger='pydantic_harness.memory'): + tools['save_memory']('abd', 'second value') + assert not any('possible duplicate' in record.message.lower() for record in caplog.records) + + # --- Instructions --- From d9ce68835d1114464d432d534c070444d6ca0b8e Mon 
Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Thu, 2 Apr 2026 16:55:23 -0500 Subject: [PATCH 03/21] refactor(memory): add MemoryEntryDict TypedDict, eliminate avoidable Any types Co-Authored-By: Claude Opus 4.6 (1M context) --- src/pydantic_harness/memory.py | 47 ++++++++++++++++++++++++---------- 1 file changed, 34 insertions(+), 13 deletions(-) diff --git a/src/pydantic_harness/memory.py b/src/pydantic_harness/memory.py index 5649265..ee651ba 100644 --- a/src/pydantic_harness/memory.py +++ b/src/pydantic_harness/memory.py @@ -13,7 +13,7 @@ from dataclasses import dataclass, field from datetime import datetime, timedelta, timezone from pathlib import Path -from typing import Any, Protocol, runtime_checkable +from typing import Any, Protocol, TypedDict, runtime_checkable from pydantic_ai._instructions import AgentInstructions from pydantic_ai.capabilities.abstract import AbstractCapability @@ -24,6 +24,27 @@ logger = logging.getLogger(__name__) +class _MemoryEntryDictRequired(TypedDict): + """Required fields for MemoryEntryDict.""" + + key: str + content: str + + +class MemoryEntryDict(_MemoryEntryDictRequired, total=False): + """Serialized form of a MemoryEntry for JSON storage. + + Only `key` and `content` are required; the remaining fields are + optional so that `from_dict` can accept legacy data missing some keys. 
+ """ + + tags: list[str] + scope: str + expires_at: str | None + created_at: str + updated_at: str + + @dataclass class MemoryEntry: """A single memory entry with content, tags, and timestamps.""" @@ -34,14 +55,14 @@ class MemoryEntry: content: str """The content of the memory.""" - tags: list[str] = field(default_factory=list[str]) + tags: list[str] = field(default_factory=lambda: list[str]()) """Optional tags for categorization and search.""" scope: str = 'global' - """Namespace scope for this memory (default ``'global'``).""" + """Namespace scope for this memory (default `'global'`).""" expires_at: str | None = None - """Optional ISO 8601 expiration timestamp. ``None`` means no expiry.""" + """Optional ISO 8601 expiration timestamp. `None` means no expiry.""" created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) """ISO 8601 timestamp of when the memory was first created.""" @@ -55,7 +76,7 @@ def is_expired(self) -> bool: return False return datetime.fromisoformat(self.expires_at) <= datetime.now(timezone.utc) - def to_dict(self) -> dict[str, Any]: + def to_dict(self) -> MemoryEntryDict: """Serialize to a plain dict for JSON storage.""" return { 'key': self.key, @@ -68,7 +89,7 @@ def to_dict(self) -> dict[str, Any]: } @classmethod - def from_dict(cls, data: dict[str, Any]) -> MemoryEntry: + def from_dict(cls, data: MemoryEntryDict) -> MemoryEntry: """Deserialize from a plain dict.""" return cls( key=data['key'], @@ -87,7 +108,7 @@ def _score_entry(entry: MemoryEntry, words: list[str]) -> int: Each query word that appears as a whole word (case-insensitive) in the key, content, or any tag contributes one point per field it appears in. Underscores and hyphens are treated as word separators in addition to - the standard ``\b`` boundaries. + the standard `\\b` boundaries. 
""" score = 0 for word in words: @@ -160,7 +181,7 @@ def search(self, query: str, *, scope: str | None = None) -> list[MemoryEntry]: class InMemoryStore: """Dict-based in-memory store, suitable for testing. - All data lives in a plain ``dict`` and is lost when the process exits. + All data lives in a plain `dict` and is lost when the process exits. """ def __init__(self) -> None: @@ -219,7 +240,7 @@ def __init__(self, path: str | Path) -> None: def _load(self) -> None: if self._path.exists(): - raw: dict[str, Any] = json.loads(self._path.read_text(encoding='utf-8')) + raw = json.loads(self._path.read_text(encoding='utf-8')) self._entries = {key: MemoryEntry.from_dict(val) for key, val in raw.items()} def _save(self) -> None: @@ -318,7 +339,7 @@ def get_serialization_name(cls) -> str | None: def from_spec(cls, *args: Any, **kwargs: Any) -> Memory[Any]: """Create from spec arguments. - Supports `backend` kwarg: ``"memory"`` (default) or ``"file"`` (requires `path`). + Supports `backend` kwarg: `"memory"` (default) or `"file"` (requires `path`). """ backend = kwargs.pop('backend', 'memory') if backend == 'file': @@ -350,8 +371,8 @@ def get_instructions(self) -> AgentInstructions[AgentDepsT] | None: def get_toolset(self) -> AgentToolset[AgentDepsT] | None: """Return a toolset with memory management tools. - Tool functions close over ``self`` to access the store without - requiring anything from the agent's ``deps``. + Tool functions close over `self` to access the store without + requiring anything from the agent's `deps`. """ store = self.store @@ -368,7 +389,7 @@ def save_memory( key: Unique key for this memory. content: The content to remember. tags: Optional tags for categorization and search. - scope: Namespace scope (default ``'global'``). + scope: Namespace scope (default `'global'`). ttl_minutes: Optional time-to-live in minutes. The entry will expire after this duration. 
""" now = datetime.now(timezone.utc) From 63cd254c15ac424b174f3c1fab61724dd4b08958 Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Thu, 2 Apr 2026 16:58:11 -0500 Subject: [PATCH 04/21] refactor(memory): extract _BaseDictStore to deduplicate InMemoryStore and FileStore Co-Authored-By: Claude Opus 4.6 (1M context) --- src/pydantic_harness/memory.py | 57 ++++++++++------------------------ 1 file changed, 17 insertions(+), 40 deletions(-) diff --git a/src/pydantic_harness/memory.py b/src/pydantic_harness/memory.py index ee651ba..44c5e6a 100644 --- a/src/pydantic_harness/memory.py +++ b/src/pydantic_harness/memory.py @@ -178,15 +178,10 @@ def search(self, query: str, *, scope: str | None = None) -> list[MemoryEntry]: ... -class InMemoryStore: - """Dict-based in-memory store, suitable for testing. - - All data lives in a plain `dict` and is lost when the process exits. - """ +class _BaseDictStore: + """Base class for dict-backed memory stores.""" - def __init__(self) -> None: - """Initialize an empty in-memory store.""" - self._entries: dict[str, MemoryEntry] = {} + _entries: dict[str, MemoryEntry] def get(self, key: str) -> MemoryEntry | None: """Retrieve a memory entry by key.""" @@ -226,7 +221,18 @@ def search(self, query: str, *, scope: str | None = None) -> list[MemoryEntry]: return [entry for _, entry in scored] -class FileStore: +class InMemoryStore(_BaseDictStore): + """Dict-based in-memory store, suitable for testing. + + All data lives in a plain `dict` and is lost when the process exits. + """ + + def __init__(self) -> None: + """Initialize an empty in-memory store.""" + self._entries: dict[str, MemoryEntry] = {} + + +class FileStore(_BaseDictStore): """JSON-file-based store for simple on-disk persistence. Reads the file on initialization and writes back on every mutation. 
@@ -248,47 +254,18 @@ def _save(self) -> None: data = {key: entry.to_dict() for key, entry in self._entries.items()} self._path.write_text(json.dumps(data, indent=2), encoding='utf-8') - def get(self, key: str) -> MemoryEntry | None: - """Retrieve a memory entry by key.""" - return self._entries.get(key) - def put(self, entry: MemoryEntry) -> None: """Store or update a memory entry.""" - self._entries[entry.key] = entry + super().put(entry) self._save() def delete(self, key: str) -> bool: """Delete a memory entry by key.""" - existed = self._entries.pop(key, None) is not None + existed = super().delete(key) if existed: self._save() return existed - def list_all(self, *, scope: str | None = None) -> list[MemoryEntry]: - """Return all non-expired entries, optionally filtered by scope.""" - return [ - entry - for entry in self._entries.values() - if not entry.is_expired() and (scope is None or entry.scope == scope) - ] - - def search(self, query: str, *, scope: str | None = None) -> list[MemoryEntry]: - """Search non-expired entries with word-boundary matching, sorted by relevance.""" - words = query.lower().split() - if not words: - return [] - scored: list[tuple[int, MemoryEntry]] = [] - for entry in self._entries.values(): - if entry.is_expired(): - continue - if scope is not None and entry.scope != scope: - continue - score = _score_entry(entry, words) - if score > 0: - scored.append((score, entry)) - scored.sort(key=lambda pair: pair[0], reverse=True) - return [entry for _, entry in scored] - def format_entry(entry: MemoryEntry) -> str: """Format a memory entry as a human-readable string.""" From f9b10667cd20bea020370416bd097ba0e54b2ef6 Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Thu, 2 Apr 2026 17:00:03 -0500 Subject: [PATCH 05/21] fix(memory): handle malformed JSON gracefully in FileStore._load Co-Authored-By: Claude Opus 4.6 (1M context) --- src/pydantic_harness/memory.py | 11 +++++++++-- 
tests/test_memory.py | 18 ++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/src/pydantic_harness/memory.py b/src/pydantic_harness/memory.py index 44c5e6a..cd9997d 100644 --- a/src/pydantic_harness/memory.py +++ b/src/pydantic_harness/memory.py @@ -246,8 +246,15 @@ def __init__(self, path: str | Path) -> None: def _load(self) -> None: if self._path.exists(): - raw = json.loads(self._path.read_text(encoding='utf-8')) - self._entries = {key: MemoryEntry.from_dict(val) for key, val in raw.items()} + try: + raw: dict[str, MemoryEntryDict] = json.loads(self._path.read_text(encoding='utf-8')) + if not isinstance(raw, dict): # pyright: ignore[reportUnnecessaryIsInstance] + logger.warning('Memory file %s contains non-dict JSON, starting empty', self._path) + return + self._entries = {key: MemoryEntry.from_dict(val) for key, val in raw.items()} + except (json.JSONDecodeError, KeyError, TypeError) as e: + logger.warning('Failed to load memory file %s: %s, starting empty', self._path, e) + self._entries = {} def _save(self) -> None: self._path.parent.mkdir(parents=True, exist_ok=True) diff --git a/tests/test_memory.py b/tests/test_memory.py index 6aa9f1e..8d944f1 100644 --- a/tests/test_memory.py +++ b/tests/test_memory.py @@ -388,6 +388,24 @@ def test_search_empty_query(self, tmp_path: Path) -> None: store.put(MemoryEntry(key='k', content='v')) assert store.search('') == [] + def test_load_malformed_json(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + path.write_text('not json at all', encoding='utf-8') + store = FileStore(path) + assert store.list_all() == [] + + def test_load_wrong_structure(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + path.write_text('["a", "b"]', encoding='utf-8') + store = FileStore(path) + assert store.list_all() == [] + + def test_load_missing_entry_fields(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + path.write_text('{"k": {"not_a_key": "oops"}}', encoding='utf-8') + 
store = FileStore(path) + assert store.list_all() == [] + def test_scope_persists(self, tmp_path: Path) -> None: path = tmp_path / 'mem.json' store1 = FileStore(path) From 7ddf098ffcc22b2fa79655f4296dd8ae24dcb98a Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Thu, 2 Apr 2026 17:01:36 -0500 Subject: [PATCH 06/21] refactor(memory): make from_spec signature explicit, raise on unknown backend Co-Authored-By: Claude Opus 4.6 (1M context) --- src/pydantic_harness/memory.py | 32 +++++++++++++++++++++++++------- tests/test_memory.py | 21 +++++++++++++++++++++ 2 files changed, 46 insertions(+), 7 deletions(-) diff --git a/src/pydantic_harness/memory.py b/src/pydantic_harness/memory.py index cd9997d..6dce2e6 100644 --- a/src/pydantic_harness/memory.py +++ b/src/pydantic_harness/memory.py @@ -320,16 +320,34 @@ def get_serialization_name(cls) -> str | None: return 'Memory' @classmethod - def from_spec(cls, *args: Any, **kwargs: Any) -> Memory[Any]: + def from_spec( + cls, + *, + backend: str = 'memory', + path: str = '.memories.json', + inject_memories_in_instructions: bool = True, + max_instructions_memories: int = 20, + ) -> Memory[Any]: """Create from spec arguments. - Supports `backend` kwarg: `"memory"` (default) or `"file"` (requires `path`). + Args: + backend: Storage backend, `"memory"` (default) or `"file"`. + path: File path for the `"file"` backend (default `".memories.json"`). + inject_memories_in_instructions: Whether to inject memories into the system prompt. + max_instructions_memories: Maximum memories to inject into the system prompt. 
""" - backend = kwargs.pop('backend', 'memory') - if backend == 'file': - path = kwargs.pop('path', '.memories.json') - return cls(store=FileStore(path), **kwargs) - return cls(store=InMemoryStore(), **kwargs) + store: MemoryStore + if backend == 'memory': + store = InMemoryStore() + elif backend == 'file': + store = FileStore(path) + else: + raise ValueError(f'Unknown memory backend: {backend!r}. Use "memory" or "file".') + return cls( + store=store, + inject_memories_in_instructions=inject_memories_in_instructions, + max_instructions_memories=max_instructions_memories, + ) def build_instructions(self, ctx: RunContext[AgentDepsT]) -> str: """Build dynamic instructions that include currently stored memories.""" diff --git a/tests/test_memory.py b/tests/test_memory.py index 8d944f1..0057e5f 100644 --- a/tests/test_memory.py +++ b/tests/test_memory.py @@ -477,6 +477,27 @@ def test_from_spec_file(self, tmp_path: Path) -> None: cap = Memory.from_spec(backend='file', path=str(path)) assert isinstance(cap.store, FileStore) + def test_from_spec_unknown_backend(self) -> None: + import pytest + + with pytest.raises(ValueError, match='Unknown memory backend'): + Memory.from_spec(backend='redis') + + def test_from_spec_explicit_memory_backend(self) -> None: + cap = Memory.from_spec(backend='memory') + assert isinstance(cap.store, InMemoryStore) + + def test_from_spec_with_options(self, tmp_path: Path) -> None: + cap = Memory.from_spec( + backend='file', + path=str(tmp_path / 'mem.json'), + inject_memories_in_instructions=False, + max_instructions_memories=10, + ) + assert isinstance(cap.store, FileStore) + assert cap.inject_memories_in_instructions is False + assert cap.max_instructions_memories == 10 + def test_default_store(self) -> None: cap: Memory[None] = Memory() assert isinstance(cap.store, InMemoryStore) From 11e944cece812ec718d059f592b8747a6ea6c09c Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Thu, 2 Apr 2026 
17:04:14 -0500 Subject: [PATCH 07/21] test(memory): add edge case tests for scoring, similarity, format, TTL, and conformance Co-Authored-By: Claude Opus 4.6 (1M context) --- tests/test_memory.py | 76 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/tests/test_memory.py b/tests/test_memory.py index 0057e5f..82a45f7 100644 --- a/tests/test_memory.py +++ b/tests/test_memory.py @@ -106,6 +106,27 @@ def test_word_boundary_no_partial(self) -> None: entry = MemoryEntry(key='k', content='foxes jump') assert _score_entry(entry, ['fox']) == 0 + def test_regex_metacharacters_in_query(self) -> None: + entry = MemoryEntry(key='lang', content='I use c++ daily') + assert _score_entry(entry, ['c++']) == 1 + + def test_empty_words_list(self) -> None: + entry = MemoryEntry(key='k', content='hello') + assert _score_entry(entry, []) == 0 + + def test_underscore_word_boundary(self) -> None: + entry = MemoryEntry(key='user_name', content='text') + assert _score_entry(entry, ['name']) == 1 + + def test_hyphen_word_boundary(self) -> None: + entry = MemoryEntry(key='my-project', content='text') + assert _score_entry(entry, ['project']) == 1 + + def test_partial_word_match(self) -> None: + entry = MemoryEntry(key='k', content='alice likes blue') + # 'alice' matches (1), 'zzz' does not (0) = score 1 + assert _score_entry(entry, ['alice', 'zzz']) == 1 + # --- _simple_similarity --- @@ -134,6 +155,18 @@ def test_length_diff_too_large(self) -> None: def test_one_char_diff(self) -> None: assert _simple_similarity('abcdefghij_x', 'abcdefghij_y') + def test_edit_distance_exactly_three(self) -> None: + # Just over the threshold -- should NOT be similar + assert not _simple_similarity('abcdefghij_abc', 'abcdefghij_xyz') + + def test_nine_char_keys(self) -> None: + # Just below the 10-char minimum + assert not _simple_similarity('abcdefghi', 'abcdefghj') + + def test_exactly_ten_char_keys_not_similar(self) -> None: + # 10-char keys differing at position 10 
do NOT share a 10-char prefix + assert not _simple_similarity('abcdefghij', 'abcdefghik') + # --- InMemoryStore --- @@ -460,6 +493,14 @@ def test_all_extras(self) -> None: ) assert format_entry(entry) == '[k] hello (tags: t; scope: project; expires: 2099-01-01T00:00:00+00:00)' + def test_empty_content(self) -> None: + entry = MemoryEntry(key='k', content='') + assert format_entry(entry) == '[k] ' + + def test_empty_key(self) -> None: + entry = MemoryEntry(key='', content='hello') + assert format_entry(entry) == '[] hello' + # --- Memory capability --- @@ -647,6 +688,16 @@ def test_delete_missing(self) -> None: tools = self._get_tools() assert 'No memory found' in tools['delete_memory']('nope') + def test_save_with_ttl_zero(self) -> None: + store = InMemoryStore() + tools = self._get_tools(store) + tools['save_memory']('k', 'v', None, 'global', 0) + entry = store.get('k') + assert entry is not None + assert entry.expires_at is not None + # TTL=0 means it expires immediately + assert entry.is_expired() + # --- Dedup warning --- @@ -724,6 +775,16 @@ def test_instructions_disabled(self) -> None: text = cap.build_instructions(self._make_ctx()) assert 'Currently stored memories' not in text + def test_instructions_exact_max_no_overflow(self) -> None: + store = InMemoryStore() + for i in range(5): + store.put(MemoryEntry(key=f'k{i}', content=f'v{i}')) + cap: Memory[None] = Memory(store=store, max_instructions_memories=5) + text = cap.build_instructions(self._make_ctx()) + assert '... 
and' not in text + assert '[k0]' in text + assert '[k4]' in text + # --- MemoryStore protocol --- @@ -734,3 +795,18 @@ def test_in_memory_store_satisfies_protocol(self) -> None: def test_file_store_satisfies_protocol(self, tmp_path: Path) -> None: assert isinstance(FileStore(tmp_path / 'mem.json'), MemoryStore) + + +# --- AbstractCapability conformance --- + + +class TestAbstractCapabilityConformance: + def test_is_abstract_capability_subclass(self) -> None: + from pydantic_ai.capabilities.abstract import AbstractCapability + + assert issubclass(Memory, AbstractCapability) + + def test_instance_is_abstract_capability(self) -> None: + from pydantic_ai.capabilities.abstract import AbstractCapability + + assert isinstance(Memory(), AbstractCapability) From c9dc52cfd3f2b428dd5ef78eb84ffb79c066eee2 Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Thu, 2 Apr 2026 17:05:26 -0500 Subject: [PATCH 08/21] chore(memory): update exports and plan to reflect review changes Co-Authored-By: Claude Opus 4.6 (1M context) --- PLAN.md | 11 +++++++---- src/pydantic_harness/__init__.py | 3 ++- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/PLAN.md b/PLAN.md index ca533d8..0f95390 100644 --- a/PLAN.md +++ b/PLAN.md @@ -21,8 +21,12 @@ Implements a `Memory` capability (`AbstractCapability` subclass) that provides p ### Memory Model -- **`MemoryEntry`** dataclass: `key`, `content`, `tags` (list[str]), `created_at`, `updated_at` -- Search is substring-based (case-insensitive) across key, content, and tags +- **`MemoryEntry`** dataclass: `key`, `content`, `tags` (list[str]), `scope`, `expires_at`, `created_at`, `updated_at` +- **`MemoryEntryDict`** TypedDict for serialization +- Word-boundary search with relevance scoring (case-insensitive) across key, content, and tags +- Scoping/namespaces via `scope` field with filtering on search/list +- TTL/expiration via `expires_at` with `is_expired()` auto-filtering +- Dedup warning on save 
when keys are similar (Levenshtein distance <= 2) ### Spec Serialization @@ -41,11 +45,10 @@ Implements a `Memory` capability (`AbstractCapability` subclass) that provides p - `src/pydantic_harness/memory.py` - Capability, stores, entry model - `src/pydantic_harness/__init__.py` - Re-exports -- `tests/test_memory.py` - 48 tests covering all code paths +- `tests/test_memory.py` - 113 tests covering all code paths ## Future Work - Semantic/vector search backend (e.g. embedding-based `MemoryStore`) -- TTL / expiration on entries - Session-scoped memory isolation via `for_run()` - SQLite / Redis backends for production persistence diff --git a/src/pydantic_harness/__init__.py b/src/pydantic_harness/__init__.py index 62831af..d47b53e 100644 --- a/src/pydantic_harness/__init__.py +++ b/src/pydantic_harness/__init__.py @@ -7,12 +7,13 @@ # Each capability module is imported and re-exported here. # Capabilities are listed alphabetically. -from pydantic_harness.memory import FileStore, InMemoryStore, Memory, MemoryEntry, MemoryStore +from pydantic_harness.memory import FileStore, InMemoryStore, Memory, MemoryEntry, MemoryEntryDict, MemoryStore __all__: list[str] = [ 'FileStore', 'InMemoryStore', 'Memory', 'MemoryEntry', + 'MemoryEntryDict', 'MemoryStore', ] From aa80c70a2c3a0cd0570a047eddde0440440248bf Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Thu, 2 Apr 2026 17:24:55 -0500 Subject: [PATCH 09/21] feat(memory): add 3 example scripts with logfire instrumentation - personal_assistant.py: FileStore persistence, preferences, instructions injection - study_coach.py: TTL/spaced repetition, tags, search - coding_assistant.py: procedural memory, rules, search, delete All examples assert on memory state and are instrumented with logfire spans. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- examples/memory/coding_assistant.py | 101 ++++++++++++++++++++++++++ examples/memory/personal_assistant.py | 89 +++++++++++++++++++++++ examples/memory/study_coach.py | 76 +++++++++++++++++++ 3 files changed, 266 insertions(+) create mode 100644 examples/memory/coding_assistant.py create mode 100644 examples/memory/personal_assistant.py create mode 100644 examples/memory/study_coach.py diff --git a/examples/memory/coding_assistant.py b/examples/memory/coding_assistant.py new file mode 100644 index 0000000..65c4123 --- /dev/null +++ b/examples/memory/coding_assistant.py @@ -0,0 +1,101 @@ +"""Self-Improving Coding Assistant — procedural memory via instructions injection. + +Demonstrates: instructions injection as self-modifying prompt, scoping, search, delete. +""" + +from __future__ import annotations + +import sys + +import logfire +from pydantic_ai import Agent + +from pydantic_harness.memory import InMemoryStore, Memory + +logfire.configure(send_to_logfire='if-token-present') +logfire.instrument_openai() + + +def main() -> None: + """Run the coding assistant example.""" + store = InMemoryStore() + memory = Memory(store=store, max_instructions_memories=10) + + agent = Agent( + 'openai:gpt-4o-mini', + capabilities=[memory], + system_prompt=( + 'You are a coding assistant that learns from user corrections. ' + 'When the user gives you a coding rule or correction, save it as a memory ' + 'with scope "rules" and tags like ["python", "style"] or ["typescript", "testing"]. ' + 'Use descriptive keys like "rule_python_fstrings" or "rule_ts_const". ' + 'When asked to write code, search your memories for relevant rules first.' + ), + ) + + # --- Teach rules --- + with logfire.span('teach-rules'): + result1 = agent.run_sync( + 'Remember these coding rules:\n' + '1. Always use f-strings in Python, never .format() or % formatting\n' + '2. In TypeScript, prefer const over let, never use var\n' + '3. 
Always add type hints to Python function signatures' + ) + print(f'Assistant: {result1.output}') + + rules = store.list_all() + print(f'\nRules stored: {len(rules)}') + for r in rules: + print(f' [{r.key}] {r.content} (scope={r.scope}, tags={r.tags})') + + assert len(rules) >= 3, f'Expected at least 3 rules saved, got {len(rules)}' + + # Check that search works across stored rules + python_rules = store.search('python') + print(f'Rules matching "python": {len(python_rules)}') + assert len(python_rules) >= 1, 'Expected at least 1 rule matching "python"' + + # --- Verify instructions injection --- + # Build instructions should now include the rules + from unittest.mock import MagicMock + + from pydantic_ai._run_context import RunContext + from pydantic_ai.usage import RunUsage + + ctx: RunContext[None] = RunContext(deps=None, model=MagicMock(), usage=RunUsage()) + instructions = memory.build_instructions(ctx) + print(f'\nInstructions preview (first 300 chars):\n{instructions[:300]}...') + + assert 'Currently stored memories' in instructions, 'Expected memories in instructions' + + # --- Ask for code, verify rules are considered --- + with logfire.span('apply-rules'): + result2 = agent.run_sync( + 'Write a Python function that greets a user by name. Follow all coding rules you know.' + ) + print(f'\nAssistant: {result2.output}') + + # The output should use f-strings and type hints (based on rules) + output_lower = result2.output.lower() + assert "f'" in result2.output or 'f"' in result2.output or 'f-string' in output_lower, ( + 'Expected f-string usage in code output' + ) + + # --- Delete an obsolete rule --- + with logfire.span('delete-rule'): + result3 = agent.run_sync('Actually, the TypeScript const rule is outdated for this project. 
Delete it.') + print(f'\nAssistant: {result3.output}') + + remaining = store.list_all() + print(f'\nRules after deletion: {len(remaining)}') + for r in remaining: + print(f' [{r.key}] {r.content}') + + # Should have fewer rules now + assert len(remaining) < len(rules), 'Expected at least one rule deleted' + + print('\n--- Coding Assistant example passed! ---') + + +if __name__ == '__main__': + sys.exit(main() or 0) diff --git a/examples/memory/personal_assistant.py b/examples/memory/personal_assistant.py new file mode 100644 index 0000000..cd477f1 --- /dev/null +++ b/examples/memory/personal_assistant.py @@ -0,0 +1,89 @@ +"""Personal Assistant — remembers user preferences across sessions. + +Demonstrates: FileStore persistence, save/recall, instructions injection, tags, scoping. +""" + +from __future__ import annotations + +import sys +import tempfile +from pathlib import Path + +import logfire +from pydantic_ai import Agent + +from pydantic_harness.memory import FileStore, Memory + +logfire.configure(send_to_logfire='if-token-present') +logfire.instrument_openai() + + +def main() -> None: + """Run the personal assistant example.""" + with tempfile.TemporaryDirectory() as tmpdir: + mem_path = Path(tmpdir) / 'preferences.json' + store = FileStore(mem_path) + memory = Memory(store=store) + + agent = Agent( + 'openai:gpt-4o-mini', + capabilities=[memory], + system_prompt=( + 'You are a helpful personal assistant. ' + 'When the user tells you about their preferences, save each one as a memory ' + 'with scope "user_prefs" and appropriate tags. ' + 'Use descriptive keys like "preferred_name" or "theme_preference".' + ), + ) + + # --- Session 1: user shares preferences --- + with logfire.span('session-1-save-preferences'): + result1 = agent.run_sync("Hi! 
My name is Alice, I prefer dark mode, and I'm vegetarian.") + print(f'Assistant: {result1.output}') + + entries = store.list_all() + print(f'\nMemories after session 1: {len(entries)}') + for e in entries: + print(f' [{e.key}] {e.content} (tags={e.tags}, scope={e.scope})') + + assert len(entries) >= 2, f'Expected at least 2 memories saved, got {len(entries)}' + all_content = ' '.join(e.content.lower() for e in entries) + assert 'alice' in all_content or any('alice' in e.key.lower() for e in entries), 'Expected a memory about Alice' + + # --- Session 2: new agent instance loads from same file (persistence) --- + store2 = FileStore(mem_path) + memory2 = Memory(store=store2) + agent2 = Agent( + 'openai:gpt-4o-mini', + capabilities=[memory2], + system_prompt='You are a helpful personal assistant.', + ) + + loaded_entries = store2.list_all() + print(f'\nMemories loaded in session 2: {len(loaded_entries)}') + assert len(loaded_entries) == len(entries), 'FileStore persistence failed' + + with logfire.span('session-2-recall-preferences'): + result2 = agent2.run_sync('What do you know about me?') + print(f'Assistant: {result2.output}') + + # The instructions injection should have included the memories + assert 'alice' in result2.output.lower() or 'dark' in result2.output.lower(), ( + 'Expected assistant to recall preferences from instructions injection' + ) + + # --- Session 3: update a preference --- + with logfire.span('session-3-update-preference'): + result3 = agent2.run_sync('Actually, I go by Ali now. Please update my name.') + print(f'\nAssistant: {result3.output}') + + updated_entries = store2.list_all() + print(f'\nMemories after update: {len(updated_entries)}') + for e in updated_entries: + print(f' [{e.key}] {e.content} (tags={e.tags})') + + print('\n--- Personal Assistant example passed! 
---') + + +if __name__ == '__main__': + sys.exit(main() or 0) diff --git a/examples/memory/study_coach.py b/examples/memory/study_coach.py new file mode 100644 index 0000000..702b335 --- /dev/null +++ b/examples/memory/study_coach.py @@ -0,0 +1,76 @@ +"""Study Coach — spaced repetition with TTL. + +Demonstrates: TTL/expiration, save with ttl_minutes, list/search, tags. +""" + +from __future__ import annotations + +import sys + +import logfire +from pydantic_ai import Agent + +from pydantic_harness.memory import InMemoryStore, Memory + +logfire.configure(send_to_logfire='if-token-present') +logfire.instrument_openai() + + +def main() -> None: + """Run the study coach example.""" + store = InMemoryStore() + memory = Memory(store=store) + + agent = Agent( + 'openai:gpt-4o-mini', + capabilities=[memory], + system_prompt=( + 'You are a study coach that helps users learn facts. ' + 'When the user provides a fact to learn, save it as a memory with ' + 'tag "study" and a ttl_minutes value: use 1 for new/hard facts, ' + '60 for reviewed facts, and 1440 for mastered facts. ' + 'Use descriptive keys like "biology_mitochondria" or "history_magna_carta".' + ), + ) + + # --- Learn some facts --- + with logfire.span('learn-facts'): + result1 = agent.run_sync( + 'I need to learn these facts:\n' + '1. Mitochondria are the powerhouse of the cell\n' + '2. The Magna Carta was signed in 1215\n' + '3. 
Water boils at 100 degrees Celsius at sea level' + ) + print(f'Coach: {result1.output}') + + entries = store.list_all() + print(f'\nFacts stored: {len(entries)}') + for e in entries: + print(f' [{e.key}] {e.content} (tags={e.tags}, ttl={e.expires_at})') + + assert len(entries) >= 3, f'Expected at least 3 facts saved, got {len(entries)}' + + # Check that TTL was set on at least some entries + entries_with_ttl = [e for e in entries if e.expires_at is not None] + assert len(entries_with_ttl) >= 1, 'Expected at least 1 entry with TTL set' + print(f'Entries with TTL: {len(entries_with_ttl)}') + + # Check tags + entries_with_study_tag = [e for e in entries if 'study' in e.tags] + assert len(entries_with_study_tag) >= 1, 'Expected at least 1 entry with "study" tag' + + # --- Search for facts --- + with logfire.span('search-facts'): + result2 = agent.run_sync('Search my memories for anything about biology.') + print(f'\nCoach: {result2.output}') + + # --- List all facts --- + with logfire.span('list-facts'): + result3 = agent.run_sync('List all my study memories.') + print(f'\nCoach: {result3.output}') + + print('\n--- Study Coach example passed! 
---') + + +if __name__ == '__main__': + sys.exit(main() or 0) From 58c70a716a74434742c216b7420e6eb4176d9a2f Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Sat, 4 Apr 2026 11:27:47 -0500 Subject: [PATCH 10/21] chore: remove settings.local.json from tracking, restore original deps Co-Authored-By: Claude Opus 4.6 (1M context) --- .agents/settings.local.json | 5 ----- .gitignore | 5 +++++ uv.lock | 6 +++--- 3 files changed, 8 insertions(+), 8 deletions(-) delete mode 100644 .agents/settings.local.json diff --git a/.agents/settings.local.json b/.agents/settings.local.json deleted file mode 100644 index 8b311a3..0000000 --- a/.agents/settings.local.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "permissions": { - "allow": [] - } -} diff --git a/.gitignore b/.gitignore index 00e73a6..36c255e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,8 @@ +.env* +.mcp.json +.DS_Store +.agents/settings.local.json + # IDE .idea/ diff --git a/uv.lock b/uv.lock index 0730281..6178ac2 100644 --- a/uv.lock +++ b/uv.lock @@ -319,15 +319,15 @@ wheels = [ [[package]] name = "opentelemetry-api" -version = "1.40.0" +version = "1.39.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "importlib-metadata" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/2c/1d/4049a9e8698361cc1a1aa03a6c59e4fa4c71e0c0f94a30f988a6876a2ae6/opentelemetry_api-1.40.0.tar.gz", hash = "sha256:159be641c0b04d11e9ecd576906462773eb97ae1b657730f0ecf64d32071569f", size = 70851, upload-time = "2026-03-04T14:17:21.555Z" } +sdist = { url = "https://files.pythonhosted.org/packages/97/b9/3161be15bb8e3ad01be8be5a968a9237c3027c5be504362ff800fca3e442/opentelemetry_api-1.39.1.tar.gz", hash = "sha256:fbde8c80e1b937a2c61f20347e91c0c18a1940cecf012d62e65a7caf08967c9c", size = 65767, upload-time = "2025-12-11T13:32:39.182Z" } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/5f/bf/93795954016c522008da367da292adceed71cca6ee1717e1d64c83089099/opentelemetry_api-1.40.0-py3-none-any.whl", hash = "sha256:82dd69331ae74b06f6a874704be0cfaa49a1650e1537d4a813b86ecef7d0ecf9", size = 68676, upload-time = "2026-03-04T14:17:01.24Z" }, + { url = "https://files.pythonhosted.org/packages/cf/df/d3f1ddf4bb4cb50ed9b1139cc7b1c54c34a1e7ce8fd1b9a37c0d1551a6bd/opentelemetry_api-1.39.1-py3-none-any.whl", hash = "sha256:2edd8463432a7f8443edce90972169b195e7d6a05500cd29e6d13898187c9950", size = 66356, upload-time = "2025-12-11T13:32:17.304Z" }, ] [[package]] From 864cc1fc28a4817b1cca9474b4068cea852258eb Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Thu, 7 May 2026 09:23:51 -0500 Subject: [PATCH 11/21] chore: exclude examples/ from pyright strict typecheck MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The example scripts use logfire (not in any dep group) and other runtime-only imports. Strict pyright on them blocks pre-commit hooks without adding value — examples are illustrative, not part of the typed library surface. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 0d573a0..8c5e481 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,6 +79,7 @@ quote-style = 'single' [tool.pyright] pythonVersion = '3.10' typeCheckingMode = 'strict' +exclude = ['examples', '.venv', '**/node_modules', '**/__pycache__'] executionEnvironments = [ { root = 'tests', reportPrivateUsage = false }, ] From 2edaf875934c176375d8c34a7fd67225a5a8c94f Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Thu, 7 May 2026 09:23:59 -0500 Subject: [PATCH 12/21] fix(memory): drop expired entries on read and persist _BaseDictStore.get() now filters expired entries (consistency with list_all/search) and FileStore._save() drops expired entries before writing, so long-running file-backed agents no longer accumulate dead records on disk. Tightens is_expired docstring with wall-clock semantics. TTL=0 now correctly results in an immediately-invisible entry (test updated accordingly). Co-Authored-By: Claude Opus 4.7 (1M context) --- src/pydantic_harness/memory.py | 25 ++++++++++++++++++----- tests/test_memory.py | 36 +++++++++++++++++++++++++++++----- 2 files changed, 51 insertions(+), 10 deletions(-) diff --git a/src/pydantic_harness/memory.py b/src/pydantic_harness/memory.py index 6dce2e6..013b2a7 100644 --- a/src/pydantic_harness/memory.py +++ b/src/pydantic_harness/memory.py @@ -71,7 +71,14 @@ class MemoryEntry: """ISO 8601 timestamp of the last update.""" def is_expired(self) -> bool: - """Return True if this entry has passed its expiration time.""" + """Return True if this entry has passed its expiration time. + + Wall-clock semantics: an entry created with `ttl_minutes=N` expires `N` + minutes after creation in real time, regardless of how many agent turns + or sessions have elapsed. 
TTL is opt-in (default `expires_at=None` = + no expiry) and intended for facts with a real-world lifetime + (verification codes, session credentials, etc.). + """ if self.expires_at is None: return False return datetime.fromisoformat(self.expires_at) <= datetime.now(timezone.utc) @@ -184,8 +191,11 @@ class _BaseDictStore: _entries: dict[str, MemoryEntry] def get(self, key: str) -> MemoryEntry | None: - """Retrieve a memory entry by key.""" - return self._entries.get(key) + """Retrieve a non-expired memory entry by key.""" + entry = self._entries.get(key) + if entry is None or entry.is_expired(): + return None + return entry def put(self, entry: MemoryEntry) -> None: """Store or update a memory entry.""" @@ -195,6 +205,12 @@ def delete(self, key: str) -> bool: """Delete a memory entry by key.""" return self._entries.pop(key, None) is not None + def _gc_expired(self) -> None: + """Drop expired entries from the backing dict.""" + expired_keys = [key for key, entry in self._entries.items() if entry.is_expired()] + for key in expired_keys: + del self._entries[key] + def list_all(self, *, scope: str | None = None) -> list[MemoryEntry]: """Return all non-expired entries, optionally filtered by scope.""" return [ @@ -258,6 +274,7 @@ def _load(self) -> None: def _save(self) -> None: self._path.parent.mkdir(parents=True, exist_ok=True) + self._gc_expired() data = {key: entry.to_dict() for key, entry in self._entries.items()} self._path.write_text(json.dumps(data, indent=2), encoding='utf-8') @@ -432,8 +449,6 @@ def recall_memory(key: str) -> str: entry = store.get(key) if entry is None: return f'No memory found for key: {key}' - if entry.is_expired(): - return f'No memory found for key: {key}' return format_entry(entry) def search_memories(query: str, scope: str | None = None) -> str: diff --git a/tests/test_memory.py b/tests/test_memory.py index 82a45f7..321502e 100644 --- a/tests/test_memory.py +++ b/tests/test_memory.py @@ -221,6 +221,12 @@ def 
test_list_all_filters_expired(self) -> None: assert len(entries) == 1 assert entries[0].key == 'alive' + def test_get_filters_expired(self) -> None: + store = InMemoryStore() + past = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() + store.put(MemoryEntry(key='dead', content='stale', expires_at=past)) + assert store.get('dead') is None + def test_list_all_scope_filter(self) -> None: store = InMemoryStore() store.put(MemoryEntry(key='a', content='x', scope='project')) @@ -458,6 +464,27 @@ def test_expires_at_persists(self, tmp_path: Path) -> None: assert entry is not None assert entry.expires_at == future + def test_save_drops_expired(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileStore(path) + past = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() + store.put(MemoryEntry(key='dead', content='stale', expires_at=past)) + store.put(MemoryEntry(key='alive', content='fresh')) + + raw = json.loads(path.read_text(encoding='utf-8')) + assert 'dead' not in raw + assert 'alive' in raw + + def test_get_filters_expired(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileStore(path) + future = (datetime.now(timezone.utc) + timedelta(seconds=1)).isoformat() + store.put(MemoryEntry(key='soon', content='v', expires_at=future)) + # Manually backdate by mutating the in-memory entry's expires_at + past = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() + store._entries['soon'].expires_at = past + assert store.get('soon') is None + # --- format_entry --- @@ -692,11 +719,10 @@ def test_save_with_ttl_zero(self) -> None: store = InMemoryStore() tools = self._get_tools(store) tools['save_memory']('k', 'v', None, 'global', 0) - entry = store.get('k') - assert entry is not None - assert entry.expires_at is not None - # TTL=0 means it expires immediately - assert entry.is_expired() + # TTL=0 expires immediately; get() filters it out + assert store.get('k') is None + # recall_memory should likewise 
report no memory + assert 'No memory found' in tools['recall_memory']('k') # --- Dedup warning --- From e5354117ac463b0f64282c87f54a7bc0c369836e Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Thu, 7 May 2026 09:26:21 -0500 Subject: [PATCH 13/21] =?UTF-8?q?refactor(memory):=20rename=20InMemoryStor?= =?UTF-8?q?e=E2=86=92DictMemoryStore,=20FileStore=E2=86=92FileMemoryStore?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Avoids collision with PR #176's SessionPersistence (InMemorySessionStore / FileSessionStore) and pydantic-ai's binary file-store branch. The MemoryStore convention makes the domain explicit at import sites — `DictMemoryStore` describes what backs the store (`dict`), `FileMemoryStore` describes its persistence target. No behavior change. Co-Authored-By: Claude Opus 4.7 (1M context) --- PLAN.md | 8 +- examples/memory/coding_assistant.py | 4 +- examples/memory/personal_assistant.py | 10 +- examples/memory/study_coach.py | 4 +- src/pydantic_harness/__init__.py | 6 +- src/pydantic_harness/memory.py | 20 ++-- tests/test_memory.py | 158 +++++++++++++------------- 7 files changed, 105 insertions(+), 105 deletions(-) diff --git a/PLAN.md b/PLAN.md index 0f95390..abb9074 100644 --- a/PLAN.md +++ b/PLAN.md @@ -16,8 +16,8 @@ Implements a `Memory` capability (`AbstractCapability` subclass) that provides p ### Storage - **`MemoryStore`** protocol: pluggable backend with `get`, `put`, `delete`, `list_all`, `search` -- **`InMemoryStore`**: dict-based, ephemeral, for testing (default) -- **`FileStore`**: JSON file on disk, reads on init, writes on every mutation +- **`DictMemoryStore`**: dict-based, ephemeral, for testing (default) +- **`FileMemoryStore`**: JSON file on disk, reads on init, writes on every mutation ### Memory Model @@ -31,13 +31,13 @@ Implements a `Memory` capability (`AbstractCapability` subclass) that provides p ### Spec Serialization - 
`Memory.get_serialization_name()` returns `"Memory"` -- `Memory.from_spec(backend="file", path="...")` creates a `FileStore`-backed instance +- `Memory.from_spec(backend="file", path="...")` creates a `FileMemoryStore`-backed instance ## Configuration | Field | Default | Description | |-------|---------|-------------| -| `store` | `InMemoryStore()` | Storage backend | +| `store` | `DictMemoryStore()` | Storage backend | | `inject_memories_in_instructions` | `True` | Include memories in system prompt | | `max_instructions_memories` | `20` | Cap on memories injected into prompt | diff --git a/examples/memory/coding_assistant.py b/examples/memory/coding_assistant.py index 65c4123..36a9c91 100644 --- a/examples/memory/coding_assistant.py +++ b/examples/memory/coding_assistant.py @@ -10,7 +10,7 @@ import logfire from pydantic_ai import Agent -from pydantic_harness.memory import InMemoryStore, Memory +from pydantic_harness.memory import DictMemoryStore, Memory logfire.configure(send_to_logfire='if-token-present') logfire.instrument_openai() @@ -18,7 +18,7 @@ def main() -> None: """Run the coding assistant example.""" - store = InMemoryStore() + store = DictMemoryStore() memory = Memory(store=store, max_instructions_memories=10) agent = Agent( diff --git a/examples/memory/personal_assistant.py b/examples/memory/personal_assistant.py index cd477f1..c3c266a 100644 --- a/examples/memory/personal_assistant.py +++ b/examples/memory/personal_assistant.py @@ -1,6 +1,6 @@ """Personal Assistant — remembers user preferences across sessions. -Demonstrates: FileStore persistence, save/recall, instructions injection, tags, scoping. +Demonstrates: FileMemoryStore persistence, save/recall, instructions injection, tags, scoping. 
""" from __future__ import annotations @@ -12,7 +12,7 @@ import logfire from pydantic_ai import Agent -from pydantic_harness.memory import FileStore, Memory +from pydantic_harness.memory import FileMemoryStore, Memory logfire.configure(send_to_logfire='if-token-present') logfire.instrument_openai() @@ -22,7 +22,7 @@ def main() -> None: """Run the personal assistant example.""" with tempfile.TemporaryDirectory() as tmpdir: mem_path = Path(tmpdir) / 'preferences.json' - store = FileStore(mem_path) + store = FileMemoryStore(mem_path) memory = Memory(store=store) agent = Agent( @@ -51,7 +51,7 @@ def main() -> None: assert 'alice' in all_content or any('alice' in e.key.lower() for e in entries), 'Expected a memory about Alice' # --- Session 2: new agent instance loads from same file (persistence) --- - store2 = FileStore(mem_path) + store2 = FileMemoryStore(mem_path) memory2 = Memory(store=store2) agent2 = Agent( 'openai:gpt-4o-mini', @@ -61,7 +61,7 @@ def main() -> None: loaded_entries = store2.list_all() print(f'\nMemories loaded in session 2: {len(loaded_entries)}') - assert len(loaded_entries) == len(entries), 'FileStore persistence failed' + assert len(loaded_entries) == len(entries), 'FileMemoryStore persistence failed' with logfire.span('session-2-recall-preferences'): result2 = agent2.run_sync('What do you know about me?') diff --git a/examples/memory/study_coach.py b/examples/memory/study_coach.py index 702b335..324a2a5 100644 --- a/examples/memory/study_coach.py +++ b/examples/memory/study_coach.py @@ -10,7 +10,7 @@ import logfire from pydantic_ai import Agent -from pydantic_harness.memory import InMemoryStore, Memory +from pydantic_harness.memory import DictMemoryStore, Memory logfire.configure(send_to_logfire='if-token-present') logfire.instrument_openai() @@ -18,7 +18,7 @@ def main() -> None: """Run the study coach example.""" - store = InMemoryStore() + store = DictMemoryStore() memory = Memory(store=store) agent = Agent( diff --git 
a/src/pydantic_harness/__init__.py b/src/pydantic_harness/__init__.py index d47b53e..6b699d8 100644 --- a/src/pydantic_harness/__init__.py +++ b/src/pydantic_harness/__init__.py @@ -7,11 +7,11 @@ # Each capability module is imported and re-exported here. # Capabilities are listed alphabetically. -from pydantic_harness.memory import FileStore, InMemoryStore, Memory, MemoryEntry, MemoryEntryDict, MemoryStore +from pydantic_harness.memory import DictMemoryStore, FileMemoryStore, Memory, MemoryEntry, MemoryEntryDict, MemoryStore __all__: list[str] = [ - 'FileStore', - 'InMemoryStore', + 'DictMemoryStore', + 'FileMemoryStore', 'Memory', 'MemoryEntry', 'MemoryEntryDict', diff --git a/src/pydantic_harness/memory.py b/src/pydantic_harness/memory.py index 013b2a7..d640cff 100644 --- a/src/pydantic_harness/memory.py +++ b/src/pydantic_harness/memory.py @@ -1,8 +1,8 @@ """Memory capability for persistent agent memory across sessions. Provides tools for saving, recalling, searching, listing, and deleting -key-value memories, with pluggable storage backends (`InMemoryStore` for -testing, `FileStore` for on-disk persistence). +key-value memories, with pluggable storage backends (`DictMemoryStore` for +testing, `FileMemoryStore` for on-disk persistence). """ from __future__ import annotations @@ -237,7 +237,7 @@ def search(self, query: str, *, scope: str | None = None) -> list[MemoryEntry]: return [entry for _, entry in scored] -class InMemoryStore(_BaseDictStore): +class DictMemoryStore(_BaseDictStore): """Dict-based in-memory store, suitable for testing. All data lives in a plain `dict` and is lost when the process exits. @@ -248,7 +248,7 @@ def __init__(self) -> None: self._entries: dict[str, MemoryEntry] = {} -class FileStore(_BaseDictStore): +class FileMemoryStore(_BaseDictStore): """JSON-file-based store for simple on-disk persistence. Reads the file on initialization and writes back on every mutation. 
@@ -316,14 +316,14 @@ class Memory(AbstractCapability[AgentDepsT]): Example: ```python {test="skip" lint="skip"} from pydantic_ai import Agent - from pydantic_harness.memory import Memory, InMemoryStore + from pydantic_harness.memory import Memory, DictMemoryStore - agent = Agent('openai:gpt-4o', capabilities=[Memory(store=InMemoryStore())]) + agent = Agent('openai:gpt-4o', capabilities=[Memory(store=DictMemoryStore())]) ``` """ - store: MemoryStore = field(default_factory=InMemoryStore) - """The storage backend. Defaults to `InMemoryStore` (ephemeral, dict-based).""" + store: MemoryStore = field(default_factory=DictMemoryStore) + """The storage backend. Defaults to `DictMemoryStore` (ephemeral, dict-based).""" inject_memories_in_instructions: bool = True """Whether to inject existing memories into the system prompt at run start.""" @@ -355,9 +355,9 @@ def from_spec( """ store: MemoryStore if backend == 'memory': - store = InMemoryStore() + store = DictMemoryStore() elif backend == 'file': - store = FileStore(path) + store = FileMemoryStore(path) else: raise ValueError(f'Unknown memory backend: {backend!r}. 
Use "memory" or "file".') return cls( diff --git a/tests/test_memory.py b/tests/test_memory.py index 321502e..62ac8e2 100644 --- a/tests/test_memory.py +++ b/tests/test_memory.py @@ -13,8 +13,8 @@ from pydantic_ai.usage import RunUsage from pydantic_harness.memory import ( - FileStore, - InMemoryStore, + DictMemoryStore, + FileMemoryStore, Memory, MemoryEntry, MemoryStore, @@ -168,22 +168,22 @@ def test_exactly_ten_char_keys_not_similar(self) -> None: assert not _simple_similarity('abcdefghij', 'abcdefghik') -# --- InMemoryStore --- +# --- DictMemoryStore --- -class TestInMemoryStore: +class TestDictMemoryStore: def test_put_and_get(self) -> None: - store = InMemoryStore() + store = DictMemoryStore() entry = MemoryEntry(key='greeting', content='hello') store.put(entry) assert store.get('greeting') is entry def test_get_missing(self) -> None: - store = InMemoryStore() + store = DictMemoryStore() assert store.get('nope') is None def test_put_overwrites(self) -> None: - store = InMemoryStore() + store = DictMemoryStore() store.put(MemoryEntry(key='k', content='v1')) store.put(MemoryEntry(key='k', content='v2')) result = store.get('k') @@ -191,21 +191,21 @@ def test_put_overwrites(self) -> None: assert result.content == 'v2' def test_delete_existing(self) -> None: - store = InMemoryStore() + store = DictMemoryStore() store.put(MemoryEntry(key='k', content='v')) assert store.delete('k') is True assert store.get('k') is None def test_delete_missing(self) -> None: - store = InMemoryStore() + store = DictMemoryStore() assert store.delete('nope') is False def test_list_all_empty(self) -> None: - store = InMemoryStore() + store = DictMemoryStore() assert store.list_all() == [] def test_list_all(self) -> None: - store = InMemoryStore() + store = DictMemoryStore() store.put(MemoryEntry(key='a', content='alpha')) store.put(MemoryEntry(key='b', content='beta')) entries = store.list_all() @@ -213,7 +213,7 @@ def test_list_all(self) -> None: assert {e.key for e in entries} == 
{'a', 'b'} def test_list_all_filters_expired(self) -> None: - store = InMemoryStore() + store = DictMemoryStore() past = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() store.put(MemoryEntry(key='alive', content='fresh')) store.put(MemoryEntry(key='dead', content='stale', expires_at=past)) @@ -222,13 +222,13 @@ def test_list_all_filters_expired(self) -> None: assert entries[0].key == 'alive' def test_get_filters_expired(self) -> None: - store = InMemoryStore() + store = DictMemoryStore() past = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() store.put(MemoryEntry(key='dead', content='stale', expires_at=past)) assert store.get('dead') is None def test_list_all_scope_filter(self) -> None: - store = InMemoryStore() + store = DictMemoryStore() store.put(MemoryEntry(key='a', content='x', scope='project')) store.put(MemoryEntry(key='b', content='y', scope='global')) entries = store.list_all(scope='project') @@ -236,13 +236,13 @@ def test_list_all_scope_filter(self) -> None: assert entries[0].key == 'a' def test_list_all_scope_none_returns_all(self) -> None: - store = InMemoryStore() + store = DictMemoryStore() store.put(MemoryEntry(key='a', content='x', scope='project')) store.put(MemoryEntry(key='b', content='y', scope='global')) assert len(store.list_all(scope=None)) == 2 def test_search_by_key(self) -> None: - store = InMemoryStore() + store = DictMemoryStore() store.put(MemoryEntry(key='user_name', content='Alice')) store.put(MemoryEntry(key='color', content='blue')) results = store.search('user') @@ -250,7 +250,7 @@ def test_search_by_key(self) -> None: assert results[0].key == 'user_name' def test_search_by_content(self) -> None: - store = InMemoryStore() + store = DictMemoryStore() store.put(MemoryEntry(key='k1', content='the quick brown fox')) store.put(MemoryEntry(key='k2', content='lazy dog')) results = store.search('fox') @@ -258,7 +258,7 @@ def test_search_by_content(self) -> None: assert results[0].key == 'k1' def 
test_search_by_tag(self) -> None: - store = InMemoryStore() + store = DictMemoryStore() store.put(MemoryEntry(key='k1', content='x', tags=['important'])) store.put(MemoryEntry(key='k2', content='y', tags=['trivial'])) results = store.search('important') @@ -266,18 +266,18 @@ def test_search_by_tag(self) -> None: assert results[0].key == 'k1' def test_search_case_insensitive(self) -> None: - store = InMemoryStore() + store = DictMemoryStore() store.put(MemoryEntry(key='K1', content='Hello World')) results = store.search('hello') assert len(results) == 1 def test_search_no_results(self) -> None: - store = InMemoryStore() + store = DictMemoryStore() store.put(MemoryEntry(key='k', content='v')) assert store.search('zzz') == [] def test_search_filters_expired(self) -> None: - store = InMemoryStore() + store = DictMemoryStore() past = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() store.put(MemoryEntry(key='alive', content='hello world')) store.put(MemoryEntry(key='dead', content='hello world', expires_at=past)) @@ -286,7 +286,7 @@ def test_search_filters_expired(self) -> None: assert results[0].key == 'alive' def test_search_scope_filter(self) -> None: - store = InMemoryStore() + store = DictMemoryStore() store.put(MemoryEntry(key='a', content='hello world', scope='project')) store.put(MemoryEntry(key='b', content='hello world', scope='global')) results = store.search('hello', scope='project') @@ -294,7 +294,7 @@ def test_search_scope_filter(self) -> None: assert results[0].key == 'a' def test_search_relevance_ordering(self) -> None: - store = InMemoryStore() + store = DictMemoryStore() # 'hello' appears in key + content = score 2 store.put(MemoryEntry(key='hello', content='hello there')) # 'hello' appears only in content = score 1 @@ -305,58 +305,58 @@ def test_search_relevance_ordering(self) -> None: assert results[1].key == 'other' def test_search_empty_query(self) -> None: - store = InMemoryStore() + store = DictMemoryStore() 
store.put(MemoryEntry(key='k', content='v')) assert store.search('') == [] -# --- FileStore --- +# --- FileMemoryStore --- -class TestFileStore: +class TestFileMemoryStore: def test_put_and_get(self, tmp_path: Path) -> None: path = tmp_path / 'mem.json' - store = FileStore(path) + store = FileMemoryStore(path) store.put(MemoryEntry(key='k', content='v')) assert store.get('k') is not None assert store.get('k').content == 'v' # type: ignore[union-attr] def test_persistence(self, tmp_path: Path) -> None: path = tmp_path / 'mem.json' - store1 = FileStore(path) + store1 = FileMemoryStore(path) store1.put(MemoryEntry(key='k', content='persisted')) # New store instance should load from disk - store2 = FileStore(path) + store2 = FileMemoryStore(path) result = store2.get('k') assert result is not None assert result.content == 'persisted' def test_delete_saves(self, tmp_path: Path) -> None: path = tmp_path / 'mem.json' - store = FileStore(path) + store = FileMemoryStore(path) store.put(MemoryEntry(key='k', content='v')) store.delete('k') # Reload and verify deletion persisted - store2 = FileStore(path) + store2 = FileMemoryStore(path) assert store2.get('k') is None def test_delete_missing(self, tmp_path: Path) -> None: path = tmp_path / 'mem.json' - store = FileStore(path) + store = FileMemoryStore(path) assert store.delete('nope') is False def test_list_all(self, tmp_path: Path) -> None: path = tmp_path / 'mem.json' - store = FileStore(path) + store = FileMemoryStore(path) store.put(MemoryEntry(key='a', content='alpha')) store.put(MemoryEntry(key='b', content='beta')) assert len(store.list_all()) == 2 def test_search(self, tmp_path: Path) -> None: path = tmp_path / 'mem.json' - store = FileStore(path) + store = FileMemoryStore(path) store.put(MemoryEntry(key='k1', content='hello', tags=['greeting'])) store.put(MemoryEntry(key='k2', content='world')) assert len(store.search('greeting')) == 1 @@ -365,18 +365,18 @@ def test_search(self, tmp_path: Path) -> None: def 
test_empty_file(self, tmp_path: Path) -> None: path = tmp_path / 'mem.json' # File does not exist yet - store = FileStore(path) + store = FileMemoryStore(path) assert store.list_all() == [] def test_creates_parent_dirs(self, tmp_path: Path) -> None: path = tmp_path / 'sub' / 'dir' / 'mem.json' - store = FileStore(path) + store = FileMemoryStore(path) store.put(MemoryEntry(key='k', content='v')) assert path.exists() def test_file_format(self, tmp_path: Path) -> None: path = tmp_path / 'mem.json' - store = FileStore(path) + store = FileMemoryStore(path) store.put(MemoryEntry(key='k', content='v', tags=['t'], created_at='c', updated_at='u')) raw = json.loads(path.read_text()) assert raw == { @@ -393,7 +393,7 @@ def test_file_format(self, tmp_path: Path) -> None: def test_list_all_filters_expired(self, tmp_path: Path) -> None: path = tmp_path / 'mem.json' - store = FileStore(path) + store = FileMemoryStore(path) past = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() store.put(MemoryEntry(key='alive', content='x')) store.put(MemoryEntry(key='dead', content='y', expires_at=past)) @@ -401,7 +401,7 @@ def test_list_all_filters_expired(self, tmp_path: Path) -> None: def test_search_filters_expired(self, tmp_path: Path) -> None: path = tmp_path / 'mem.json' - store = FileStore(path) + store = FileMemoryStore(path) past = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() store.put(MemoryEntry(key='alive', content='hello world')) store.put(MemoryEntry(key='dead', content='hello world', expires_at=past)) @@ -409,47 +409,47 @@ def test_search_filters_expired(self, tmp_path: Path) -> None: def test_list_all_scope(self, tmp_path: Path) -> None: path = tmp_path / 'mem.json' - store = FileStore(path) + store = FileMemoryStore(path) store.put(MemoryEntry(key='a', content='x', scope='project')) store.put(MemoryEntry(key='b', content='y', scope='global')) assert len(store.list_all(scope='project')) == 1 def test_search_scope(self, tmp_path: Path) -> None: path 
= tmp_path / 'mem.json' - store = FileStore(path) + store = FileMemoryStore(path) store.put(MemoryEntry(key='a', content='hello world', scope='project')) store.put(MemoryEntry(key='b', content='hello world', scope='global')) assert len(store.search('hello', scope='project')) == 1 def test_search_empty_query(self, tmp_path: Path) -> None: path = tmp_path / 'mem.json' - store = FileStore(path) + store = FileMemoryStore(path) store.put(MemoryEntry(key='k', content='v')) assert store.search('') == [] def test_load_malformed_json(self, tmp_path: Path) -> None: path = tmp_path / 'mem.json' path.write_text('not json at all', encoding='utf-8') - store = FileStore(path) + store = FileMemoryStore(path) assert store.list_all() == [] def test_load_wrong_structure(self, tmp_path: Path) -> None: path = tmp_path / 'mem.json' path.write_text('["a", "b"]', encoding='utf-8') - store = FileStore(path) + store = FileMemoryStore(path) assert store.list_all() == [] def test_load_missing_entry_fields(self, tmp_path: Path) -> None: path = tmp_path / 'mem.json' path.write_text('{"k": {"not_a_key": "oops"}}', encoding='utf-8') - store = FileStore(path) + store = FileMemoryStore(path) assert store.list_all() == [] def test_scope_persists(self, tmp_path: Path) -> None: path = tmp_path / 'mem.json' - store1 = FileStore(path) + store1 = FileMemoryStore(path) store1.put(MemoryEntry(key='k', content='v', scope='session')) - store2 = FileStore(path) + store2 = FileMemoryStore(path) entry = store2.get('k') assert entry is not None assert entry.scope == 'session' @@ -457,16 +457,16 @@ def test_scope_persists(self, tmp_path: Path) -> None: def test_expires_at_persists(self, tmp_path: Path) -> None: path = tmp_path / 'mem.json' future = (datetime.now(timezone.utc) + timedelta(hours=1)).isoformat() - store1 = FileStore(path) + store1 = FileMemoryStore(path) store1.put(MemoryEntry(key='k', content='v', expires_at=future)) - store2 = FileStore(path) + store2 = FileMemoryStore(path) entry = 
store2.get('k') assert entry is not None assert entry.expires_at == future def test_save_drops_expired(self, tmp_path: Path) -> None: path = tmp_path / 'mem.json' - store = FileStore(path) + store = FileMemoryStore(path) past = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() store.put(MemoryEntry(key='dead', content='stale', expires_at=past)) store.put(MemoryEntry(key='alive', content='fresh')) @@ -477,7 +477,7 @@ def test_save_drops_expired(self, tmp_path: Path) -> None: def test_get_filters_expired(self, tmp_path: Path) -> None: path = tmp_path / 'mem.json' - store = FileStore(path) + store = FileMemoryStore(path) future = (datetime.now(timezone.utc) + timedelta(seconds=1)).isoformat() store.put(MemoryEntry(key='soon', content='v', expires_at=future)) # Manually backdate by mutating the in-memory entry's expires_at @@ -538,12 +538,12 @@ def test_serialization_name(self) -> None: def test_from_spec_default(self) -> None: cap = Memory.from_spec() - assert isinstance(cap.store, InMemoryStore) + assert isinstance(cap.store, DictMemoryStore) def test_from_spec_file(self, tmp_path: Path) -> None: path = tmp_path / 'mem.json' cap = Memory.from_spec(backend='file', path=str(path)) - assert isinstance(cap.store, FileStore) + assert isinstance(cap.store, FileMemoryStore) def test_from_spec_unknown_backend(self) -> None: import pytest @@ -553,7 +553,7 @@ def test_from_spec_unknown_backend(self) -> None: def test_from_spec_explicit_memory_backend(self) -> None: cap = Memory.from_spec(backend='memory') - assert isinstance(cap.store, InMemoryStore) + assert isinstance(cap.store, DictMemoryStore) def test_from_spec_with_options(self, tmp_path: Path) -> None: cap = Memory.from_spec( @@ -562,13 +562,13 @@ def test_from_spec_with_options(self, tmp_path: Path) -> None: inject_memories_in_instructions=False, max_instructions_memories=10, ) - assert isinstance(cap.store, FileStore) + assert isinstance(cap.store, FileMemoryStore) assert 
cap.inject_memories_in_instructions is False assert cap.max_instructions_memories == 10 def test_default_store(self) -> None: cap: Memory[None] = Memory() - assert isinstance(cap.store, InMemoryStore) + assert isinstance(cap.store, DictMemoryStore) def test_get_toolset_returns_function_toolset(self) -> None: cap: Memory[None] = Memory() @@ -590,14 +590,14 @@ class TestMemoryTools: """Test the tool functions exposed by the Memory capability.""" @staticmethod - def _get_tools(store: InMemoryStore | None = None) -> dict[str, Any]: - cap: Memory[None] = Memory(store=store or InMemoryStore()) + def _get_tools(store: DictMemoryStore | None = None) -> dict[str, Any]: + cap: Memory[None] = Memory(store=store or DictMemoryStore()) toolset = cap.get_toolset() assert isinstance(toolset, FunctionToolset) return {name: tool.function for name, tool in toolset.tools.items()} def test_save_and_recall(self) -> None: - store = InMemoryStore() + store = DictMemoryStore() tools = self._get_tools(store) result = tools['save_memory']('greeting', 'hello world') assert result == 'Memory saved: greeting' @@ -610,14 +610,14 @@ def test_recall_missing(self) -> None: assert 'No memory found' in tools['recall_memory']('nope') def test_recall_expired(self) -> None: - store = InMemoryStore() + store = DictMemoryStore() past = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() store.put(MemoryEntry(key='old', content='stale', expires_at=past)) tools = self._get_tools(store) assert 'No memory found' in tools['recall_memory']('old') def test_save_updates_existing(self) -> None: - store = InMemoryStore() + store = DictMemoryStore() tools = self._get_tools(store) tools['save_memory']('k', 'v1') original = store.get('k') @@ -632,7 +632,7 @@ def test_save_updates_existing(self) -> None: assert updated.created_at == original_created def test_save_with_tags(self) -> None: - store = InMemoryStore() + store = DictMemoryStore() tools = self._get_tools(store) tools['save_memory']('k', 'v', 
['tag1', 'tag2']) entry = store.get('k') @@ -640,7 +640,7 @@ def test_save_with_tags(self) -> None: assert entry.tags == ['tag1', 'tag2'] def test_save_with_scope(self) -> None: - store = InMemoryStore() + store = DictMemoryStore() tools = self._get_tools(store) tools['save_memory']('k', 'v', None, 'project') entry = store.get('k') @@ -648,7 +648,7 @@ def test_save_with_scope(self) -> None: assert entry.scope == 'project' def test_save_with_ttl(self) -> None: - store = InMemoryStore() + store = DictMemoryStore() tools = self._get_tools(store) tools['save_memory']('k', 'v', None, 'global', 60) entry = store.get('k') @@ -660,7 +660,7 @@ def test_save_with_ttl(self) -> None: assert expires < datetime.now(timezone.utc) + timedelta(minutes=61) def test_search(self) -> None: - store = InMemoryStore() + store = DictMemoryStore() tools = self._get_tools(store) tools['save_memory']('user_name', 'Alice') tools['save_memory']('color', 'blue') @@ -674,7 +674,7 @@ def test_search_no_results(self) -> None: assert 'No memories found' in tools['search_memories']('zzz') def test_search_with_scope(self) -> None: - store = InMemoryStore() + store = DictMemoryStore() tools = self._get_tools(store) tools['save_memory']('a', 'hello world', None, 'project') tools['save_memory']('b', 'hello world', None, 'global') @@ -687,7 +687,7 @@ def test_list_empty(self) -> None: assert tools['list_memories']() == 'No memories stored.' 
def test_list_with_entries(self) -> None: - store = InMemoryStore() + store = DictMemoryStore() tools = self._get_tools(store) tools['save_memory']('a', 'alpha') tools['save_memory']('b', 'beta') @@ -696,7 +696,7 @@ def test_list_with_entries(self) -> None: assert '[b] beta' in result def test_list_with_scope(self) -> None: - store = InMemoryStore() + store = DictMemoryStore() tools = self._get_tools(store) tools['save_memory']('a', 'alpha', None, 'project') tools['save_memory']('b', 'beta', None, 'global') @@ -705,7 +705,7 @@ def test_list_with_scope(self) -> None: assert '[b]' not in result def test_delete_existing(self) -> None: - store = InMemoryStore() + store = DictMemoryStore() tools = self._get_tools(store) tools['save_memory']('k', 'v') assert tools['delete_memory']('k') == 'Memory deleted: k' @@ -716,7 +716,7 @@ def test_delete_missing(self) -> None: assert 'No memory found' in tools['delete_memory']('nope') def test_save_with_ttl_zero(self) -> None: - store = InMemoryStore() + store = DictMemoryStore() tools = self._get_tools(store) tools['save_memory']('k', 'v', None, 'global', 0) # TTL=0 expires immediately; get() filters it out @@ -730,7 +730,7 @@ def test_save_with_ttl_zero(self) -> None: class TestDedupWarning: def test_similar_key_logs_warning(self, caplog: Any) -> None: - store = InMemoryStore() + store = DictMemoryStore() tools = TestMemoryTools._get_tools(store) tools['save_memory']('abcdefghij_x', 'first value') with caplog.at_level(logging.WARNING, logger='pydantic_harness.memory'): @@ -738,7 +738,7 @@ def test_similar_key_logs_warning(self, caplog: Any) -> None: assert any('possible duplicate' in record.message.lower() for record in caplog.records) def test_different_keys_no_warning(self, caplog: Any) -> None: - store = InMemoryStore() + store = DictMemoryStore() tools = TestMemoryTools._get_tools(store) tools['save_memory']('first_key_long', 'first value') with caplog.at_level(logging.WARNING, logger='pydantic_harness.memory'): @@ -746,7 
+746,7 @@ def test_different_keys_no_warning(self, caplog: Any) -> None: assert not any('possible duplicate' in record.message.lower() for record in caplog.records) def test_short_keys_no_warning(self, caplog: Any) -> None: - store = InMemoryStore() + store = DictMemoryStore() tools = TestMemoryTools._get_tools(store) tools['save_memory']('abc', 'first value') with caplog.at_level(logging.WARNING, logger='pydantic_harness.memory'): @@ -779,7 +779,7 @@ def test_instructions_with_no_memories(self) -> None: assert 'Currently stored memories' not in text def test_instructions_with_memories(self) -> None: - store = InMemoryStore() + store = DictMemoryStore() store.put(MemoryEntry(key='user', content='Alice')) cap: Memory[None] = Memory(store=store) text = cap.build_instructions(self._make_ctx()) @@ -787,7 +787,7 @@ def test_instructions_with_memories(self) -> None: assert '[user] Alice' in text def test_instructions_respects_max(self) -> None: - store = InMemoryStore() + store = DictMemoryStore() for i in range(25): store.put(MemoryEntry(key=f'k{i}', content=f'v{i}')) cap: Memory[None] = Memory(store=store, max_instructions_memories=5) @@ -795,14 +795,14 @@ def test_instructions_respects_max(self) -> None: assert '... 
and 20 more' in text def test_instructions_disabled(self) -> None: - store = InMemoryStore() + store = DictMemoryStore() store.put(MemoryEntry(key='k', content='v')) cap: Memory[None] = Memory(store=store, inject_memories_in_instructions=False) text = cap.build_instructions(self._make_ctx()) assert 'Currently stored memories' not in text def test_instructions_exact_max_no_overflow(self) -> None: - store = InMemoryStore() + store = DictMemoryStore() for i in range(5): store.put(MemoryEntry(key=f'k{i}', content=f'v{i}')) cap: Memory[None] = Memory(store=store, max_instructions_memories=5) @@ -817,10 +817,10 @@ def test_instructions_exact_max_no_overflow(self) -> None: class TestMemoryStoreProtocol: def test_in_memory_store_satisfies_protocol(self) -> None: - assert isinstance(InMemoryStore(), MemoryStore) + assert isinstance(DictMemoryStore(), MemoryStore) def test_file_store_satisfies_protocol(self, tmp_path: Path) -> None: - assert isinstance(FileStore(tmp_path / 'mem.json'), MemoryStore) + assert isinstance(FileMemoryStore(tmp_path / 'mem.json'), MemoryStore) # --- AbstractCapability conformance --- From 48bc0f577fa346bd21d2ccda77b1a8ddc47002b8 Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Thu, 7 May 2026 09:30:57 -0500 Subject: [PATCH 14/21] feat(memory): add summary, metadata, read_only, char_limit, importance to MemoryEntry MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Five additive fields aligning MemoryEntry with the structured-record conventions in Mem0/LangGraph/Letta: - summary: short version preferred over content for system-prompt injection (used by build_instructions in a follow-up commit) - metadata: JSON-serializable structured attributes for filterable search (used by search filter= in a follow-up commit) - read_only: pin entry against modification by save_memory/delete_memory tools — useful for system-curated facts (persona, policies) - char_limit: hard 
cap on content length, enforced at MemoryEntry construction; raises ValueError when exceeded - importance: relevance booster for search scoring The save_memory tool now accepts summary and importance as optional LLM-facing parameters; metadata, read_only, and char_limit are dev-only (set via direct MemoryEntry construction). save_memory and delete_memory refuse to modify read_only entries with a clear message. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/pydantic_harness/memory.py | 49 ++++++++++++++++++++++++++ tests/test_memory.py | 63 ++++++++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+) diff --git a/src/pydantic_harness/memory.py b/src/pydantic_harness/memory.py index d640cff..d325825 100644 --- a/src/pydantic_harness/memory.py +++ b/src/pydantic_harness/memory.py @@ -43,6 +43,11 @@ class MemoryEntryDict(_MemoryEntryDictRequired, total=False): expires_at: str | None created_at: str updated_at: str + summary: str | None + metadata: dict[str, object] + read_only: bool + char_limit: int | None + importance: float | None @dataclass @@ -70,6 +75,28 @@ class MemoryEntry: updated_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) """ISO 8601 timestamp of the last update.""" + summary: str | None = None + """Optional short summary used by `Memory.build_instructions` when injecting this entry into the system prompt; falls back to `content` when None.""" + + metadata: dict[str, object] = field(default_factory=lambda: dict[str, object]()) + """Structured attributes for filterable search (`MemoryStore.search(filter=...)`). Values must be JSON-serializable.""" + + read_only: bool = False + """If True, the agent's `save_memory` and `delete_memory` tools refuse to modify this entry. Programmatic access via the store is unrestricted.""" + + char_limit: int | None = None + """Optional hard cap on `content` length (chars). 
Enforced at `MemoryEntry` construction; raises `ValueError` if exceeded.""" + + importance: float | None = None + """Optional relevance booster used by `MemoryStore.search` scoring when set.""" + + def __post_init__(self) -> None: + """Validate `char_limit` immediately so dev errors surface at construction.""" + if self.char_limit is not None and len(self.content) > self.char_limit: + raise ValueError( + f'MemoryEntry {self.key!r} content is {len(self.content)} chars, exceeds char_limit={self.char_limit}', + ) + def is_expired(self) -> bool: """Return True if this entry has passed its expiration time. @@ -93,6 +120,11 @@ def to_dict(self) -> MemoryEntryDict: 'expires_at': self.expires_at, 'created_at': self.created_at, 'updated_at': self.updated_at, + 'summary': self.summary, + 'metadata': self.metadata, + 'read_only': self.read_only, + 'char_limit': self.char_limit, + 'importance': self.importance, } @classmethod @@ -106,6 +138,11 @@ def from_dict(cls, data: MemoryEntryDict) -> MemoryEntry: expires_at=data.get('expires_at'), created_at=data.get('created_at', ''), updated_at=data.get('updated_at', ''), + summary=data.get('summary'), + metadata=data.get('metadata', {}), + read_only=data.get('read_only', False), + char_limit=data.get('char_limit'), + importance=data.get('importance'), ) @@ -401,6 +438,8 @@ def save_memory( tags: list[str] | None = None, scope: str = 'global', ttl_minutes: int | None = None, + summary: str | None = None, + importance: float | None = None, ) -> str: """Save or update a memory entry. @@ -410,11 +449,16 @@ def save_memory( tags: Optional tags for categorization and search. scope: Namespace scope (default `'global'`). ttl_minutes: Optional time-to-live in minutes. The entry will expire after this duration. + summary: Optional short summary; preferred over `content` when injecting into the system prompt. + importance: Optional relevance booster (e.g., 0.0–1.0); used by search scoring. 
""" now = datetime.now(timezone.utc) now_iso = now.isoformat() existing = store.get(key) + if existing is not None and existing.read_only: + return f'Memory {key!r} is read-only and cannot be modified.' + # Dedup warning: check for similar keys among existing entries for existing_entry in store.list_all(): if _simple_similarity(key, existing_entry.key): @@ -436,6 +480,8 @@ def save_memory( expires_at=expires_at, created_at=existing.created_at if existing else now_iso, updated_at=now_iso, + summary=summary, + importance=importance, ) store.put(entry) return f'Memory saved: {key}' @@ -480,6 +526,9 @@ def delete_memory(key: str) -> str: Args: key: The key of the memory to delete. """ + entry = store.get(key) + if entry is not None and entry.read_only: + return f'Memory {key!r} is read-only and cannot be deleted.' if store.delete(key): return f'Memory deleted: {key}' return f'No memory found for key: {key}' diff --git a/tests/test_memory.py b/tests/test_memory.py index 62ac8e2..9fd9817 100644 --- a/tests/test_memory.py +++ b/tests/test_memory.py @@ -70,6 +70,36 @@ def test_is_expired_past(self) -> None: entry = MemoryEntry(key='k', content='v', expires_at=past) assert entry.is_expired() + def test_default_new_fields(self) -> None: + entry = MemoryEntry(key='k', content='v') + assert entry.summary is None + assert entry.metadata == {} + assert entry.read_only is False + assert entry.char_limit is None + assert entry.importance is None + + def test_round_trip_with_new_fields(self) -> None: + entry = MemoryEntry( + key='k', + content='v', + summary='short', + metadata={'priority': 1, 'source': 'manual'}, + read_only=True, + char_limit=100, + importance=0.8, + ) + assert MemoryEntry.from_dict(entry.to_dict()) == entry + + def test_char_limit_enforced(self) -> None: + import pytest + + with pytest.raises(ValueError, match='exceeds char_limit'): + MemoryEntry(key='k', content='hello world', char_limit=5) + + def test_char_limit_allows_exact(self) -> None: + # Exactly at the 
limit is allowed + MemoryEntry(key='k', content='hello', char_limit=5) + # --- _score_entry --- @@ -388,6 +418,11 @@ def test_file_format(self, tmp_path: Path) -> None: 'expires_at': None, 'created_at': 'c', 'updated_at': 'u', + 'summary': None, + 'metadata': {}, + 'read_only': False, + 'char_limit': None, + 'importance': None, } } @@ -724,6 +759,34 @@ def test_save_with_ttl_zero(self) -> None: # recall_memory should likewise report no memory assert 'No memory found' in tools['recall_memory']('k') + def test_save_with_summary_and_importance(self) -> None: + store = DictMemoryStore() + tools = self._get_tools(store) + tools['save_memory']('k', 'long content here', None, 'global', None, 'short', 0.9) + entry = store.get('k') + assert entry is not None + assert entry.summary == 'short' + assert entry.importance == 0.9 + + def test_save_refuses_read_only(self) -> None: + store = DictMemoryStore() + store.put(MemoryEntry(key='persona', content='locked', read_only=True)) + tools = self._get_tools(store) + result = tools['save_memory']('persona', 'overwrite attempt') + assert 'read-only' in result.lower() + # Original content preserved + entry = store.get('persona') + assert entry is not None + assert entry.content == 'locked' + + def test_delete_refuses_read_only(self) -> None: + store = DictMemoryStore() + store.put(MemoryEntry(key='persona', content='locked', read_only=True)) + tools = self._get_tools(store) + result = tools['delete_memory']('persona') + assert 'read-only' in result.lower() + assert store.get('persona') is not None + # --- Dedup warning --- From b6ef5c5b3fc354ce0af707dcc76f6871151eef58 Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Thu, 7 May 2026 09:36:30 -0500 Subject: [PATCH 15/21] feat(memory): replace scope: str with namespace: tuple[str, ...] MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Borrowed from LangGraph BaseStore. 
Hierarchical namespaces enable patterns like ('users', 'alice', 'prefs') and ('agents', 'planner', 'facts') that flat scope strings can't represent. Filters use prefix matching: namespace=('users',) matches all entries under that root. Changes: - MemoryEntry.scope (str, default 'global') → namespace (tuple[str, ...], default ('global',)) - MemoryStore.list_all/search now take namespace=tuple[str, ...] | None - New MemoryStore.list_namespaces(prefix, suffix, max_depth) returns unique namespaces in the store, sorted - save_memory/search_memories/list_memories tools accept list[str] and coerce to tuple internally - format_entry shows nested namespaces as 'a/b/c'; the default ('global',) is still omitted for brevity - Added _namespace_matches helper for prefix-match logic - Tests cover prefix matching, max_depth truncation, suffix filtering, and persistence of nested namespaces Co-Authored-By: Claude Opus 4.7 (1M context) --- examples/memory/coding_assistant.py | 4 +- examples/memory/personal_assistant.py | 4 +- src/pydantic_harness/memory.py | 119 ++++++++++++++++----- tests/test_memory.py | 143 ++++++++++++++++++-------- 4 files changed, 199 insertions(+), 71 deletions(-) diff --git a/examples/memory/coding_assistant.py b/examples/memory/coding_assistant.py index 36a9c91..a3a1637 100644 --- a/examples/memory/coding_assistant.py +++ b/examples/memory/coding_assistant.py @@ -27,7 +27,7 @@ def main() -> None: system_prompt=( 'You are a coding assistant that learns from user corrections. ' 'When the user gives you a coding rule or correction, save it as a memory ' - 'with scope "rules" and tags like ["python", "style"] or ["typescript", "testing"]. ' + 'with namespace ["rules"] and tags like ["python", "style"] or ["typescript", "testing"]. ' 'Use descriptive keys like "rule_python_fstrings" or "rule_ts_const". ' 'When asked to write code, search your memories for relevant rules first.' 
), @@ -46,7 +46,7 @@ def main() -> None: rules = store.list_all() print(f'\nRules stored: {len(rules)}') for r in rules: - print(f' [{r.key}] {r.content} (scope={r.scope}, tags={r.tags})') + print(f' [{r.key}] {r.content} (namespace={r.namespace}, tags={r.tags})') assert len(rules) >= 3, f'Expected at least 3 rules saved, got {len(rules)}' diff --git a/examples/memory/personal_assistant.py b/examples/memory/personal_assistant.py index c3c266a..849e9a4 100644 --- a/examples/memory/personal_assistant.py +++ b/examples/memory/personal_assistant.py @@ -31,7 +31,7 @@ def main() -> None: system_prompt=( 'You are a helpful personal assistant. ' 'When the user tells you about their preferences, save each one as a memory ' - 'with scope "user_prefs" and appropriate tags. ' + 'with namespace ["user_prefs"] and appropriate tags. ' 'Use descriptive keys like "preferred_name" or "theme_preference".' ), ) @@ -44,7 +44,7 @@ def main() -> None: entries = store.list_all() print(f'\nMemories after session 1: {len(entries)}') for e in entries: - print(f' [{e.key}] {e.content} (tags={e.tags}, scope={e.scope})') + print(f' [{e.key}] {e.content} (tags={e.tags}, namespace={e.namespace})') assert len(entries) >= 2, f'Expected at least 2 memories saved, got {len(entries)}' all_content = ' '.join(e.content.lower() for e in entries) diff --git a/src/pydantic_harness/memory.py b/src/pydantic_harness/memory.py index d325825..73ec6f1 100644 --- a/src/pydantic_harness/memory.py +++ b/src/pydantic_harness/memory.py @@ -39,7 +39,7 @@ class MemoryEntryDict(_MemoryEntryDictRequired, total=False): """ tags: list[str] - scope: str + namespace: list[str] expires_at: str | None created_at: str updated_at: str @@ -63,8 +63,14 @@ class MemoryEntry: tags: list[str] = field(default_factory=lambda: list[str]()) """Optional tags for categorization and search.""" - scope: str = 'global' - """Namespace scope for this memory (default `'global'`).""" + namespace: tuple[str, ...] 
= ('global',) + """Hierarchical namespace for this memory. + + A tuple of strings forming a path-like namespace (e.g., `('users', 'alice')`, + `('agents', 'planner', 'facts')`). Filters in `list_all`/`search` use prefix + matching: `namespace=('users',)` matches `('users', 'alice')` and + `('users', 'bob')`. Default `('global',)` is a single-segment namespace. + """ expires_at: str | None = None """Optional ISO 8601 expiration timestamp. `None` means no expiry.""" @@ -116,7 +122,7 @@ def to_dict(self) -> MemoryEntryDict: 'key': self.key, 'content': self.content, 'tags': self.tags, - 'scope': self.scope, + 'namespace': list(self.namespace), 'expires_at': self.expires_at, 'created_at': self.created_at, 'updated_at': self.updated_at, @@ -134,7 +140,7 @@ def from_dict(cls, data: MemoryEntryDict) -> MemoryEntry: key=data['key'], content=data['content'], tags=data.get('tags', []), - scope=data.get('scope', 'global'), + namespace=tuple(data.get('namespace', ('global',))), expires_at=data.get('expires_at'), created_at=data.get('created_at', ''), updated_at=data.get('updated_at', ''), @@ -168,6 +174,17 @@ def _score_entry(entry: MemoryEntry, words: list[str]) -> int: return score +def _namespace_matches(entry_ns: tuple[str, ...], filter_prefix: tuple[str, ...]) -> bool: + """Return True if `entry_ns` starts with `filter_prefix`. + + Empty `filter_prefix` matches any namespace (use `None` in callers to mean + "no filter" — this helper assumes the caller has decided to filter). + """ + if len(entry_ns) < len(filter_prefix): + return False + return entry_ns[: len(filter_prefix)] == filter_prefix + + def _simple_similarity(a: str, b: str) -> bool: """Return True if two keys share the same first 10 characters and differ only slightly. @@ -213,14 +230,35 @@ def delete(self, key: str) -> bool: # pragma: no cover """Delete a memory entry by key. Returns True if it existed.""" ... 
- def list_all(self, *, scope: str | None = None) -> list[MemoryEntry]: # pragma: no cover - """Return all non-expired entries, optionally filtered by scope.""" + def list_all(self, *, namespace: tuple[str, ...] | None = None) -> list[MemoryEntry]: # pragma: no cover + """Return all non-expired entries, optionally filtered by namespace prefix.""" ... - def search(self, query: str, *, scope: str | None = None) -> list[MemoryEntry]: # pragma: no cover + def search( + self, + query: str, + *, + namespace: tuple[str, ...] | None = None, + ) -> list[MemoryEntry]: # pragma: no cover """Search non-expired entries with word-boundary matching, sorted by relevance.""" ... + def list_namespaces( # pragma: no cover + self, + *, + prefix: tuple[str, ...] | None = None, + suffix: tuple[str, ...] | None = None, + max_depth: int | None = None, + ) -> list[tuple[str, ...]]: + """List unique namespaces among non-expired entries, optionally filtered. + + Args: + prefix: Only include namespaces starting with this prefix. + suffix: Only include namespaces ending with this suffix. + max_depth: Truncate each namespace to at most this many segments before deduplication. + """ + ... + class _BaseDictStore: """Base class for dict-backed memory stores.""" @@ -248,15 +286,20 @@ def _gc_expired(self) -> None: for key in expired_keys: del self._entries[key] - def list_all(self, *, scope: str | None = None) -> list[MemoryEntry]: - """Return all non-expired entries, optionally filtered by scope.""" + def list_all(self, *, namespace: tuple[str, ...] 
| None = None) -> list[MemoryEntry]: + """Return all non-expired entries, optionally filtered by namespace prefix.""" return [ entry for entry in self._entries.values() - if not entry.is_expired() and (scope is None or entry.scope == scope) + if not entry.is_expired() and (namespace is None or _namespace_matches(entry.namespace, namespace)) ] - def search(self, query: str, *, scope: str | None = None) -> list[MemoryEntry]: + def search( + self, + query: str, + *, + namespace: tuple[str, ...] | None = None, + ) -> list[MemoryEntry]: """Search non-expired entries with word-boundary matching, sorted by relevance.""" words = query.lower().split() if not words: @@ -265,7 +308,7 @@ def search(self, query: str, *, scope: str | None = None) -> list[MemoryEntry]: for entry in self._entries.values(): if entry.is_expired(): continue - if scope is not None and entry.scope != scope: + if namespace is not None and not _namespace_matches(entry.namespace, namespace): continue score = _score_entry(entry, words) if score > 0: @@ -273,6 +316,28 @@ def search(self, query: str, *, scope: str | None = None) -> list[MemoryEntry]: scored.sort(key=lambda pair: pair[0], reverse=True) return [entry for _, entry in scored] + def list_namespaces( + self, + *, + prefix: tuple[str, ...] | None = None, + suffix: tuple[str, ...] | None = None, + max_depth: int | None = None, + ) -> list[tuple[str, ...]]: + """List unique namespaces among non-expired entries.""" + seen: set[tuple[str, ...]] = set() + for entry in self._entries.values(): + if entry.is_expired(): + continue + ns = entry.namespace + if max_depth is not None: + ns = ns[:max_depth] + if prefix is not None and not _namespace_matches(ns, prefix): + continue + if suffix is not None and (len(ns) < len(suffix) or ns[-len(suffix) :] != suffix): + continue + seen.add(ns) + return sorted(seen) + class DictMemoryStore(_BaseDictStore): """Dict-based in-memory store, suitable for testing. 
@@ -334,8 +399,8 @@ def format_entry(entry: MemoryEntry) -> str: extras: list[str] = [] if entry.tags: extras.append(f'tags: {", ".join(entry.tags)}') - if entry.scope != 'global': - extras.append(f'scope: {entry.scope}') + if entry.namespace != ('global',): + extras.append(f'namespace: {"/".join(entry.namespace)}') if entry.expires_at is not None: extras.append(f'expires: {entry.expires_at}') if extras: @@ -436,7 +501,7 @@ def save_memory( key: str, content: str, tags: list[str] | None = None, - scope: str = 'global', + namespace: list[str] | None = None, ttl_minutes: int | None = None, summary: str | None = None, importance: float | None = None, @@ -447,7 +512,8 @@ def save_memory( key: Unique key for this memory. content: The content to remember. tags: Optional tags for categorization and search. - scope: Namespace scope (default `'global'`). + namespace: Optional hierarchical namespace as a list of segments + (e.g., `['users', 'alice']`). Defaults to `['global']`. ttl_minutes: Optional time-to-live in minutes. The entry will expire after this duration. summary: Optional short summary; preferred over `content` when injecting into the system prompt. importance: Optional relevance booster (e.g., 0.0–1.0); used by search scoring. @@ -472,11 +538,12 @@ def save_memory( if ttl_minutes is not None: expires_at = (now + timedelta(minutes=ttl_minutes)).isoformat() + ns: tuple[str, ...] 
= tuple(namespace) if namespace else ('global',) entry = MemoryEntry( key=key, content=content, tags=tags or [], - scope=scope, + namespace=ns, expires_at=expires_at, created_at=existing.created_at if existing else now_iso, updated_at=now_iso, @@ -497,25 +564,27 @@ def recall_memory(key: str) -> str: return f'No memory found for key: {key}' return format_entry(entry) - def search_memories(query: str, scope: str | None = None) -> str: + def search_memories(query: str, namespace: list[str] | None = None) -> str: """Search memories by word-boundary matching on keys, content, or tags, sorted by relevance. Args: query: The search query string (space-separated words). - scope: Optional scope to restrict the search to. + namespace: Optional namespace prefix to restrict the search to (e.g., `['users']`). """ - results = store.search(query, scope=scope) + ns: tuple[str, ...] | None = tuple(namespace) if namespace else None + results = store.search(query, namespace=ns) if not results: return f'No memories found matching: {query}' return '\n'.join(format_entry(entry) for entry in results) - def list_memories(scope: str | None = None) -> str: - """List all stored memories, optionally filtered by scope. + def list_memories(namespace: list[str] | None = None) -> str: + """List all stored memories, optionally filtered by namespace prefix. Args: - scope: Optional scope to filter by. + namespace: Optional namespace prefix to filter by (e.g., `['users']`). """ - entries = store.list_all(scope=scope) + ns: tuple[str, ...] | None = tuple(namespace) if namespace else None + entries = store.list_all(namespace=ns) if not entries: return 'No memories stored.' 
return '\n'.join(format_entry(entry) for entry in entries) diff --git a/tests/test_memory.py b/tests/test_memory.py index 9fd9817..ac1ea20 100644 --- a/tests/test_memory.py +++ b/tests/test_memory.py @@ -32,7 +32,7 @@ def test_round_trip(self) -> None: key='k', content='v', tags=['a', 'b'], - scope='project', + namespace=('project',), expires_at='2099-01-01T00:00:00+00:00', created_at='t1', updated_at='t2', @@ -42,7 +42,7 @@ def test_round_trip(self) -> None: def test_from_dict_defaults(self) -> None: entry = MemoryEntry.from_dict({'key': 'k', 'content': 'v'}) assert entry.tags == [] - assert entry.scope == 'global' + assert entry.namespace == ('global',) assert entry.expires_at is None assert entry.created_at == '' assert entry.updated_at == '' @@ -52,9 +52,9 @@ def test_default_timestamps(self) -> None: assert entry.created_at # non-empty ISO string assert entry.updated_at - def test_default_scope(self) -> None: + def test_default_namespace(self) -> None: entry = MemoryEntry(key='k', content='v') - assert entry.scope == 'global' + assert entry.namespace == ('global',) def test_is_expired_no_expiry(self) -> None: entry = MemoryEntry(key='k', content='v') @@ -259,17 +259,17 @@ def test_get_filters_expired(self) -> None: def test_list_all_scope_filter(self) -> None: store = DictMemoryStore() - store.put(MemoryEntry(key='a', content='x', scope='project')) - store.put(MemoryEntry(key='b', content='y', scope='global')) - entries = store.list_all(scope='project') + store.put(MemoryEntry(key='a', content='x', namespace=('project',))) + store.put(MemoryEntry(key='b', content='y', namespace=('global',))) + entries = store.list_all(namespace=('project',)) assert len(entries) == 1 assert entries[0].key == 'a' def test_list_all_scope_none_returns_all(self) -> None: store = DictMemoryStore() - store.put(MemoryEntry(key='a', content='x', scope='project')) - store.put(MemoryEntry(key='b', content='y', scope='global')) - assert len(store.list_all(scope=None)) == 2 + 
store.put(MemoryEntry(key='a', content='x', namespace=('project',))) + store.put(MemoryEntry(key='b', content='y', namespace=('global',))) + assert len(store.list_all(namespace=None)) == 2 def test_search_by_key(self) -> None: store = DictMemoryStore() @@ -317,9 +317,9 @@ def test_search_filters_expired(self) -> None: def test_search_scope_filter(self) -> None: store = DictMemoryStore() - store.put(MemoryEntry(key='a', content='hello world', scope='project')) - store.put(MemoryEntry(key='b', content='hello world', scope='global')) - results = store.search('hello', scope='project') + store.put(MemoryEntry(key='a', content='hello world', namespace=('project',))) + store.put(MemoryEntry(key='b', content='hello world', namespace=('global',))) + results = store.search('hello', namespace=('project',)) assert len(results) == 1 assert results[0].key == 'a' @@ -339,6 +339,44 @@ def test_search_empty_query(self) -> None: store.put(MemoryEntry(key='k', content='v')) assert store.search('') == [] + def test_list_namespaces(self) -> None: + store = DictMemoryStore() + store.put(MemoryEntry(key='a', content='x', namespace=('users', 'alice'))) + store.put(MemoryEntry(key='b', content='y', namespace=('users', 'bob'))) + store.put(MemoryEntry(key='c', content='z', namespace=('agents', 'planner'))) + assert store.list_namespaces() == [('agents', 'planner'), ('users', 'alice'), ('users', 'bob')] + + def test_list_namespaces_prefix(self) -> None: + store = DictMemoryStore() + store.put(MemoryEntry(key='a', content='x', namespace=('users', 'alice'))) + store.put(MemoryEntry(key='b', content='y', namespace=('agents', 'planner'))) + assert store.list_namespaces(prefix=('users',)) == [('users', 'alice')] + + def test_list_namespaces_max_depth(self) -> None: + store = DictMemoryStore() + store.put(MemoryEntry(key='a', content='x', namespace=('users', 'alice', 'prefs'))) + store.put(MemoryEntry(key='b', content='y', namespace=('users', 'bob', 'prefs'))) + # Truncate to depth 1 → both 
collapse to ('users',) + assert store.list_namespaces(max_depth=1) == [('users',)] + + def test_list_namespaces_suffix(self) -> None: + store = DictMemoryStore() + store.put(MemoryEntry(key='a', content='x', namespace=('users', 'alice', 'prefs'))) + store.put(MemoryEntry(key='b', content='y', namespace=('agents', 'planner', 'prefs'))) + assert store.list_namespaces(suffix=('prefs',)) == [ + ('agents', 'planner', 'prefs'), + ('users', 'alice', 'prefs'), + ] + + def test_list_all_namespace_prefix_match(self) -> None: + store = DictMemoryStore() + store.put(MemoryEntry(key='a', content='x', namespace=('users', 'alice'))) + store.put(MemoryEntry(key='b', content='y', namespace=('users', 'bob'))) + store.put(MemoryEntry(key='c', content='z', namespace=('agents',))) + # Prefix ('users',) matches both alice and bob + results = store.list_all(namespace=('users',)) + assert {e.key for e in results} == {'a', 'b'} + # --- FileMemoryStore --- @@ -414,7 +452,7 @@ def test_file_format(self, tmp_path: Path) -> None: 'key': 'k', 'content': 'v', 'tags': ['t'], - 'scope': 'global', + 'namespace': ['global'], 'expires_at': None, 'created_at': 'c', 'updated_at': 'u', @@ -445,16 +483,16 @@ def test_search_filters_expired(self, tmp_path: Path) -> None: def test_list_all_scope(self, tmp_path: Path) -> None: path = tmp_path / 'mem.json' store = FileMemoryStore(path) - store.put(MemoryEntry(key='a', content='x', scope='project')) - store.put(MemoryEntry(key='b', content='y', scope='global')) - assert len(store.list_all(scope='project')) == 1 + store.put(MemoryEntry(key='a', content='x', namespace=('project',))) + store.put(MemoryEntry(key='b', content='y', namespace=('global',))) + assert len(store.list_all(namespace=('project',))) == 1 def test_search_scope(self, tmp_path: Path) -> None: path = tmp_path / 'mem.json' store = FileMemoryStore(path) - store.put(MemoryEntry(key='a', content='hello world', scope='project')) - store.put(MemoryEntry(key='b', content='hello world', 
scope='global')) - assert len(store.search('hello', scope='project')) == 1 + store.put(MemoryEntry(key='a', content='hello world', namespace=('project',))) + store.put(MemoryEntry(key='b', content='hello world', namespace=('global',))) + assert len(store.search('hello', namespace=('project',))) == 1 def test_search_empty_query(self, tmp_path: Path) -> None: path = tmp_path / 'mem.json' @@ -480,14 +518,23 @@ def test_load_missing_entry_fields(self, tmp_path: Path) -> None: store = FileMemoryStore(path) assert store.list_all() == [] - def test_scope_persists(self, tmp_path: Path) -> None: + def test_namespace_persists(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store1 = FileMemoryStore(path) + store1.put(MemoryEntry(key='k', content='v', namespace=('session',))) + store2 = FileMemoryStore(path) + entry = store2.get('k') + assert entry is not None + assert entry.namespace == ('session',) + + def test_nested_namespace_persists(self, tmp_path: Path) -> None: path = tmp_path / 'mem.json' store1 = FileMemoryStore(path) - store1.put(MemoryEntry(key='k', content='v', scope='session')) + store1.put(MemoryEntry(key='k', content='v', namespace=('users', 'alice', 'prefs'))) store2 = FileMemoryStore(path) entry = store2.get('k') assert entry is not None - assert entry.scope == 'session' + assert entry.namespace == ('users', 'alice', 'prefs') def test_expires_at_persists(self, tmp_path: Path) -> None: path = tmp_path / 'mem.json' @@ -533,12 +580,16 @@ def test_with_tags(self) -> None: entry = MemoryEntry(key='k', content='hello', tags=['a', 'b']) assert format_entry(entry) == '[k] hello (tags: a, b)' - def test_with_scope(self) -> None: - entry = MemoryEntry(key='k', content='hello', scope='project') - assert format_entry(entry) == '[k] hello (scope: project)' + def test_with_namespace(self) -> None: + entry = MemoryEntry(key='k', content='hello', namespace=('project',)) + assert format_entry(entry) == '[k] hello (namespace: project)' + + def 
test_with_nested_namespace(self) -> None: + entry = MemoryEntry(key='k', content='hello', namespace=('users', 'alice')) + assert format_entry(entry) == '[k] hello (namespace: users/alice)' - def test_global_scope_omitted(self) -> None: - entry = MemoryEntry(key='k', content='hello', scope='global') + def test_global_namespace_omitted(self) -> None: + entry = MemoryEntry(key='k', content='hello', namespace=('global',)) assert format_entry(entry) == '[k] hello' def test_with_expires_at(self) -> None: @@ -550,10 +601,10 @@ def test_all_extras(self) -> None: key='k', content='hello', tags=['t'], - scope='project', + namespace=('project',), expires_at='2099-01-01T00:00:00+00:00', ) - assert format_entry(entry) == '[k] hello (tags: t; scope: project; expires: 2099-01-01T00:00:00+00:00)' + assert format_entry(entry) == '[k] hello (tags: t; namespace: project; expires: 2099-01-01T00:00:00+00:00)' def test_empty_content(self) -> None: entry = MemoryEntry(key='k', content='') @@ -674,18 +725,26 @@ def test_save_with_tags(self) -> None: assert entry is not None assert entry.tags == ['tag1', 'tag2'] - def test_save_with_scope(self) -> None: + def test_save_with_namespace(self) -> None: + store = DictMemoryStore() + tools = self._get_tools(store) + tools['save_memory']('k', 'v', None, ['project']) + entry = store.get('k') + assert entry is not None + assert entry.namespace == ('project',) + + def test_save_with_nested_namespace(self) -> None: store = DictMemoryStore() tools = self._get_tools(store) - tools['save_memory']('k', 'v', None, 'project') + tools['save_memory']('k', 'v', None, ['users', 'alice']) entry = store.get('k') assert entry is not None - assert entry.scope == 'project' + assert entry.namespace == ('users', 'alice') def test_save_with_ttl(self) -> None: store = DictMemoryStore() tools = self._get_tools(store) - tools['save_memory']('k', 'v', None, 'global', 60) + tools['save_memory']('k', 'v', None, ['global'], 60) entry = store.get('k') assert entry is not None 
assert entry.expires_at is not None @@ -711,9 +770,9 @@ def test_search_no_results(self) -> None: def test_search_with_scope(self) -> None: store = DictMemoryStore() tools = self._get_tools(store) - tools['save_memory']('a', 'hello world', None, 'project') - tools['save_memory']('b', 'hello world', None, 'global') - result = tools['search_memories']('hello', 'project') + tools['save_memory']('a', 'hello world', None, ['project']) + tools['save_memory']('b', 'hello world', None, ['global']) + result = tools['search_memories']('hello', ['project']) assert '[a]' in result assert '[b]' not in result @@ -733,9 +792,9 @@ def test_list_with_entries(self) -> None: def test_list_with_scope(self) -> None: store = DictMemoryStore() tools = self._get_tools(store) - tools['save_memory']('a', 'alpha', None, 'project') - tools['save_memory']('b', 'beta', None, 'global') - result = tools['list_memories']('project') + tools['save_memory']('a', 'alpha', None, ['project']) + tools['save_memory']('b', 'beta', None, ['global']) + result = tools['list_memories'](['project']) assert '[a] alpha' in result assert '[b]' not in result @@ -753,7 +812,7 @@ def test_delete_missing(self) -> None: def test_save_with_ttl_zero(self) -> None: store = DictMemoryStore() tools = self._get_tools(store) - tools['save_memory']('k', 'v', None, 'global', 0) + tools['save_memory']('k', 'v', None, ['global'], 0) # TTL=0 expires immediately; get() filters it out assert store.get('k') is None # recall_memory should likewise report no memory @@ -762,7 +821,7 @@ def test_save_with_ttl_zero(self) -> None: def test_save_with_summary_and_importance(self) -> None: store = DictMemoryStore() tools = self._get_tools(store) - tools['save_memory']('k', 'long content here', None, 'global', None, 'short', 0.9) + tools['save_memory']('k', 'long content here', None, ['global'], None, 'short', 0.9) entry = store.get('k') assert entry is not None assert entry.summary == 'short' From ea57378d25fe1477c15242bca216bf7494961677 Mon 
Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Thu, 7 May 2026 09:40:02 -0500 Subject: [PATCH 16/21] feat(memory): add filter, recency_scorer, importance to MemoryStore.search MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three additions to MemoryStore.search (list_all gains only filter): - filter: dict[str, object] | None — equality match against MemoryEntry.metadata. Drops entries that don't match all keys. - recency_scorer: Callable[[MemoryEntry], float] — pluggable per-call recency boost added to the base keyword-match score. Built-in exponential_decay(half_life_days, weight) factory ships as the Memory capability default (30-day half-life, weight 0.5). Set to None to disable. - entry.importance: float | None — when set, added to the search score unconditionally so devs/agents can pin entries above keyword matches. Score formula: keyword_match_count + (importance or 0) + (recency or 0). Entries with zero keyword match are still excluded — recency and importance only re-rank within the matched set. Also adds the RecencyScorer type alias; it and exponential_decay are both exported from pydantic_harness for custom scorers (e.g. linear decay, importance-only ranking). Co-Authored-By: Claude Opus 4.7 (1M context) --- src/pydantic_harness/__init__.py | 13 +++- src/pydantic_harness/memory.py | 114 +++++++++++++++++++++++++++---- tests/test_memory.py | 67 ++++++++++++++++++ 3 files changed, 178 insertions(+), 16 deletions(-) diff --git a/src/pydantic_harness/__init__.py b/src/pydantic_harness/__init__.py index 6b699d8..69fd8c0 100644 --- a/src/pydantic_harness/__init__.py +++ b/src/pydantic_harness/__init__.py @@ -7,7 +7,16 @@ # Each capability module is imported and re-exported here. # Capabilities are listed alphabetically.
-from pydantic_harness.memory import DictMemoryStore, FileMemoryStore, Memory, MemoryEntry, MemoryEntryDict, MemoryStore +from pydantic_harness.memory import ( + DictMemoryStore, + FileMemoryStore, + Memory, + MemoryEntry, + MemoryEntryDict, + MemoryStore, + RecencyScorer, + exponential_decay, +) __all__: list[str] = [ 'DictMemoryStore', @@ -16,4 +25,6 @@ 'MemoryEntry', 'MemoryEntryDict', 'MemoryStore', + 'RecencyScorer', + 'exponential_decay', ] diff --git a/src/pydantic_harness/memory.py b/src/pydantic_harness/memory.py index 73ec6f1..e009656 100644 --- a/src/pydantic_harness/memory.py +++ b/src/pydantic_harness/memory.py @@ -10,10 +10,11 @@ import json import logging import re +from collections.abc import Callable from dataclasses import dataclass, field from datetime import datetime, timedelta, timezone from pathlib import Path -from typing import Any, Protocol, TypedDict, runtime_checkable +from typing import Any, Protocol, TypeAlias, TypedDict, runtime_checkable from pydantic_ai._instructions import AgentInstructions from pydantic_ai.capabilities.abstract import AbstractCapability @@ -174,6 +175,48 @@ def _score_entry(entry: MemoryEntry, words: list[str]) -> int: return score +RecencyScorer: TypeAlias = Callable[['MemoryEntry'], float] +"""Callable that maps a `MemoryEntry` to a recency score (typically in `[0, 1]`). + +Added to the keyword-match score in `MemoryStore.search` to bias results toward fresher entries. +Use the built-in `exponential_decay` factory or supply any callable. +""" + + +def exponential_decay(*, half_life_days: float = 30.0, weight: float = 1.0) -> RecencyScorer: + """Build a recency scorer with exponential decay over `entry.updated_at`. + + Args: + half_life_days: Age (in days) at which the decay value is halved. Default `30.0`. + weight: Multiplier applied to the decay value. Default `1.0`. + + Returns: + A `RecencyScorer` callable. Entries with an unparseable `updated_at` + return `0.0`; future-dated entries return the full `weight`. 
+ """ + + def scorer(entry: MemoryEntry) -> float: + try: + updated = datetime.fromisoformat(entry.updated_at) + except ValueError: + return 0.0 + age_seconds = (datetime.now(timezone.utc) - updated).total_seconds() + if age_seconds < 0: + return weight + age_days = age_seconds / 86400 + return weight * (2 ** (-age_days / half_life_days)) + + return scorer + + +def _matches_filter(entry: MemoryEntry, filter_: dict[str, object]) -> bool: + """Return True if all filter keys match `entry.metadata` values exactly.""" + for key, value in filter_.items(): + if entry.metadata.get(key) != value: + return False + return True + + def _namespace_matches(entry_ns: tuple[str, ...], filter_prefix: tuple[str, ...]) -> bool: """Return True if `entry_ns` starts with `filter_prefix`. @@ -230,17 +273,28 @@ def delete(self, key: str) -> bool: # pragma: no cover """Delete a memory entry by key. Returns True if it existed.""" ... - def list_all(self, *, namespace: tuple[str, ...] | None = None) -> list[MemoryEntry]: # pragma: no cover - """Return all non-expired entries, optionally filtered by namespace prefix.""" + def list_all( # pragma: no cover + self, + *, + namespace: tuple[str, ...] | None = None, + filter: dict[str, object] | None = None, + ) -> list[MemoryEntry]: + """Return all non-expired entries, optionally filtered by namespace prefix and metadata equality.""" ... - def search( + def search( # pragma: no cover self, query: str, *, namespace: tuple[str, ...] | None = None, - ) -> list[MemoryEntry]: # pragma: no cover - """Search non-expired entries with word-boundary matching, sorted by relevance.""" + filter: dict[str, object] | None = None, + recency_scorer: RecencyScorer | None = None, + ) -> list[MemoryEntry]: + """Search non-expired entries, sorted by relevance. + + Score = keyword-match count + entry.importance (if set) + recency_scorer(entry) (if provided). + Entries with zero keyword match are excluded regardless of recency or importance. + """ ... 
def list_namespaces( # pragma: no cover @@ -286,12 +340,19 @@ def _gc_expired(self) -> None: for key in expired_keys: del self._entries[key] - def list_all(self, *, namespace: tuple[str, ...] | None = None) -> list[MemoryEntry]: - """Return all non-expired entries, optionally filtered by namespace prefix.""" + def list_all( + self, + *, + namespace: tuple[str, ...] | None = None, + filter: dict[str, object] | None = None, + ) -> list[MemoryEntry]: + """Return all non-expired entries, optionally filtered by namespace prefix and metadata equality.""" return [ entry for entry in self._entries.values() - if not entry.is_expired() and (namespace is None or _namespace_matches(entry.namespace, namespace)) + if not entry.is_expired() + and (namespace is None or _namespace_matches(entry.namespace, namespace)) + and (filter is None or _matches_filter(entry, filter)) ] def search( @@ -299,20 +360,33 @@ def search( query: str, *, namespace: tuple[str, ...] | None = None, + filter: dict[str, object] | None = None, + recency_scorer: RecencyScorer | None = None, ) -> list[MemoryEntry]: - """Search non-expired entries with word-boundary matching, sorted by relevance.""" + """Search non-expired entries with word-boundary matching, sorted by relevance. + + Score = keyword-match count + `entry.importance` (if set) + `recency_scorer(entry)` (if provided). 
+ """ words = query.lower().split() if not words: return [] - scored: list[tuple[int, MemoryEntry]] = [] + scored: list[tuple[float, MemoryEntry]] = [] for entry in self._entries.values(): if entry.is_expired(): continue if namespace is not None and not _namespace_matches(entry.namespace, namespace): continue - score = _score_entry(entry, words) - if score > 0: - scored.append((score, entry)) + if filter is not None and not _matches_filter(entry, filter): + continue + base = _score_entry(entry, words) + if base == 0: + continue + score: float = float(base) + if entry.importance is not None: + score += entry.importance + if recency_scorer is not None: + score += recency_scorer(entry) + scored.append((score, entry)) scored.sort(key=lambda pair: pair[0], reverse=True) return [entry for _, entry in scored] @@ -433,6 +507,15 @@ class Memory(AbstractCapability[AgentDepsT]): max_instructions_memories: int = 20 """Maximum number of memories to include in the system prompt.""" + recency_scorer: RecencyScorer | None = field( + default_factory=lambda: exponential_decay(half_life_days=30.0, weight=0.5), + ) + """Recency scorer threaded into `search_memories` to bias results toward fresher entries. + + Defaults to `exponential_decay(half_life_days=30, weight=0.5)`. Set to `None` to disable. + Pass any `Callable[[MemoryEntry], float]` for custom decay shapes. + """ + @classmethod def get_serialization_name(cls) -> str | None: """Return the name used for spec serialization.""" @@ -496,6 +579,7 @@ def get_toolset(self) -> AgentToolset[AgentDepsT] | None: requiring anything from the agent's `deps`. """ store = self.store + recency_scorer = self.recency_scorer def save_memory( key: str, @@ -572,7 +656,7 @@ def search_memories(query: str, namespace: list[str] | None = None) -> str: namespace: Optional namespace prefix to restrict the search to (e.g., `['users']`). """ ns: tuple[str, ...] 
| None = tuple(namespace) if namespace else None - results = store.search(query, namespace=ns) + results = store.search(query, namespace=ns, recency_scorer=recency_scorer) if not results: return f'No memories found matching: {query}' return '\n'.join(format_entry(entry) for entry in results) diff --git a/tests/test_memory.py b/tests/test_memory.py index ac1ea20..2972711 100644 --- a/tests/test_memory.py +++ b/tests/test_memory.py @@ -20,6 +20,7 @@ MemoryStore, _score_entry, _simple_similarity, + exponential_decay, format_entry, ) @@ -377,6 +378,72 @@ def test_list_all_namespace_prefix_match(self) -> None: results = store.list_all(namespace=('users',)) assert {e.key for e in results} == {'a', 'b'} + def test_list_all_with_filter(self) -> None: + store = DictMemoryStore() + store.put(MemoryEntry(key='a', content='x', metadata={'priority': 1})) + store.put(MemoryEntry(key='b', content='y', metadata={'priority': 2})) + results = store.list_all(filter={'priority': 1}) + assert len(results) == 1 + assert results[0].key == 'a' + + def test_search_with_filter(self) -> None: + store = DictMemoryStore() + store.put(MemoryEntry(key='a', content='hello world', metadata={'source': 'manual'})) + store.put(MemoryEntry(key='b', content='hello world', metadata={'source': 'auto'})) + results = store.search('hello', filter={'source': 'manual'}) + assert len(results) == 1 + assert results[0].key == 'a' + + def test_search_filter_no_match(self) -> None: + store = DictMemoryStore() + store.put(MemoryEntry(key='a', content='hello world', metadata={'source': 'manual'})) + assert store.search('hello', filter={'source': 'nonexistent'}) == [] + + def test_search_importance_boosts(self) -> None: + store = DictMemoryStore() + # Both match 'hello' once in content; importance differentiates them + store.put(MemoryEntry(key='boring', content='hello there')) + store.put(MemoryEntry(key='vip', content='hello there', importance=2.0)) + results = store.search('hello') + assert results[0].key == 'vip' 
+ assert results[1].key == 'boring' + + def test_search_with_recency_scorer(self) -> None: + store = DictMemoryStore() + # Identical keyword scores; recent entry should rank first via recency_scorer. + old = (datetime.now(timezone.utc) - timedelta(days=365)).isoformat() + new = datetime.now(timezone.utc).isoformat() + store.put(MemoryEntry(key='ancient', content='hello there', updated_at=old)) + store.put(MemoryEntry(key='fresh', content='hello there', updated_at=new)) + results = store.search('hello', recency_scorer=exponential_decay(half_life_days=30.0)) + assert results[0].key == 'fresh' + assert results[1].key == 'ancient' + + +class TestExponentialDecay: + def test_fresh_entry_full_weight(self) -> None: + scorer = exponential_decay(half_life_days=30.0, weight=1.0) + entry = MemoryEntry(key='k', content='v') # updated_at = now + # Should be very close to 1.0 (essentially zero seconds elapsed) + assert 0.99 < scorer(entry) <= 1.0 + + def test_one_half_life_old(self) -> None: + scorer = exponential_decay(half_life_days=30.0, weight=1.0) + thirty_days_ago = (datetime.now(timezone.utc) - timedelta(days=30)).isoformat() + entry = MemoryEntry(key='k', content='v', updated_at=thirty_days_ago) + # ~0.5 (within float-precision tolerance) + assert 0.49 < scorer(entry) < 0.51 + + def test_invalid_updated_at_returns_zero(self) -> None: + scorer = exponential_decay() + entry = MemoryEntry(key='k', content='v', updated_at='') + assert scorer(entry) == 0.0 + + def test_weight_multiplier(self) -> None: + scorer = exponential_decay(half_life_days=30.0, weight=0.5) + entry = MemoryEntry(key='k', content='v') # fresh + assert 0.49 < scorer(entry) <= 0.5 + # --- FileMemoryStore --- From 8ad614e133460df6a8a93124811dab01352c7bd6 Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Thu, 7 May 2026 09:41:50 -0500 Subject: [PATCH 17/21] feat(memory): summary preference, byte_budget, and read_only pinning in instructions MIME-Version: 1.0 
Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit build_instructions now: - Prefers entry.summary over entry.content when injecting into the system prompt (via format_entry's new prefer_summary flag), keeping token budgets predictable for entries with long bodies. - Respects an optional byte_budget: int | None on Memory — non-pinned entries are skipped once adding the next would exceed the cap (UTF-8 bytes). - Always injects read_only=True entries regardless of count cap or byte budget. Pinned entries are listed first. Selection precedence: pinned (always) → up to max_instructions_memories non-pinned, capped by byte_budget when set. Overflow is reported with a "... and N more" suffix nudging the LLM toward search/list tools. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/pydantic_harness/memory.py | 80 ++++++++++++++++++++++++++++------ tests/test_memory.py | 75 +++++++++++++++++++++++++++++++ 2 files changed, 141 insertions(+), 14 deletions(-) diff --git a/src/pydantic_harness/memory.py b/src/pydantic_harness/memory.py index e009656..ab588c7 100644 --- a/src/pydantic_harness/memory.py +++ b/src/pydantic_harness/memory.py @@ -467,9 +467,17 @@ def delete(self, key: str) -> bool: return existed -def format_entry(entry: MemoryEntry) -> str: - """Format a memory entry as a human-readable string.""" - line = f'[{entry.key}] {entry.content}' +def format_entry(entry: MemoryEntry, *, prefer_summary: bool = False) -> str: + """Format a memory entry as a human-readable string. + + Args: + entry: The entry to format. + prefer_summary: If True and `entry.summary` is set, render the summary + in place of the full content. Used by `Memory.build_instructions` + to keep system-prompt injection short. Defaults to False (full content). 
+ """ + body = entry.summary if (prefer_summary and entry.summary is not None) else entry.content + line = f'[{entry.key}] {body}' extras: list[str] = [] if entry.tags: extras.append(f'tags: {", ".join(entry.tags)}') @@ -505,7 +513,18 @@ class Memory(AbstractCapability[AgentDepsT]): """Whether to inject existing memories into the system prompt at run start.""" max_instructions_memories: int = 20 - """Maximum number of memories to include in the system prompt.""" + """Maximum number of non-pinned memories to include in the system prompt. + + `read_only=True` entries always inject regardless of this cap. + """ + + byte_budget: int | None = None + """Optional UTF-8 byte cap on the injected memories block. + + When set, non-pinned entries are skipped once adding the next would exceed + the budget. `read_only=True` entries always inject regardless of this cap. + Default `None` = no byte cap (only the count cap applies). + """ recency_scorer: RecencyScorer | None = field( default_factory=lambda: exponential_decay(half_life_days=30.0, weight=0.5), @@ -552,20 +571,53 @@ def from_spec( ) def build_instructions(self, ctx: RunContext[AgentDepsT]) -> str: - """Build dynamic instructions that include currently stored memories.""" + """Build dynamic instructions that include currently stored memories. + + Selection rules: + - `read_only=True` entries always inject (bypass both count cap and byte budget). + - Non-pinned entries respect `max_instructions_memories` and `byte_budget`. + - When `entry.summary` is set, it's preferred over `entry.content` to save tokens. + - Pinned entries are listed first. + """ parts: list[str] = [ 'You have access to a persistent memory system. 
' 'Use it to save important information that should be remembered across conversations.', ] - if self.inject_memories_in_instructions: - entries = self.store.list_all() - if entries: - parts.append('\nCurrently stored memories:') - for entry in entries[: self.max_instructions_memories]: - parts.append(f'- {format_entry(entry)}') - overflow = len(entries) - self.max_instructions_memories - if overflow > 0: - parts.append(f'... and {overflow} more (use list_memories or search_memories to see all).') + if not self.inject_memories_in_instructions: + return '\n'.join(parts) + + entries = self.store.list_all() + if not entries: + return '\n'.join(parts) + + parts.append('\nCurrently stored memories:') + + # Pinned first, then the rest in store order + ordered = sorted(entries, key=lambda e: not e.read_only) + + formatted: list[str] = [] + used_bytes = 0 + consumed_non_pinned = 0 + for entry in ordered: + line = f'- {format_entry(entry, prefer_summary=True)}' + line_bytes = len(line.encode('utf-8')) + if entry.read_only: + formatted.append(line) + used_bytes += line_bytes + continue + if consumed_non_pinned >= self.max_instructions_memories: + break + if self.byte_budget is not None and used_bytes + line_bytes > self.byte_budget: + break + formatted.append(line) + used_bytes += line_bytes + consumed_non_pinned += 1 + + parts.extend(formatted) + + overflow = len(entries) - len(formatted) + if overflow > 0: + parts.append(f'... 
and {overflow} more (use list_memories or search_memories to see all).') return '\n'.join(parts) def get_instructions(self) -> AgentInstructions[AgentDepsT] | None: diff --git a/tests/test_memory.py b/tests/test_memory.py index 2972711..3713869 100644 --- a/tests/test_memory.py +++ b/tests/test_memory.py @@ -651,6 +651,19 @@ def test_with_namespace(self) -> None: entry = MemoryEntry(key='k', content='hello', namespace=('project',)) assert format_entry(entry) == '[k] hello (namespace: project)' + def test_prefer_summary_uses_summary(self) -> None: + entry = MemoryEntry(key='k', content='long content body', summary='short') + assert format_entry(entry, prefer_summary=True) == '[k] short' + + def test_prefer_summary_falls_back_to_content(self) -> None: + entry = MemoryEntry(key='k', content='hello') + assert format_entry(entry, prefer_summary=True) == '[k] hello' + + def test_default_prefers_content(self) -> None: + entry = MemoryEntry(key='k', content='full content', summary='short') + # prefer_summary defaults to False + assert format_entry(entry) == '[k] full content' + def test_with_nested_namespace(self) -> None: entry = MemoryEntry(key='k', content='hello', namespace=('users', 'alice')) assert format_entry(entry) == '[k] hello (namespace: users/alice)' @@ -1000,6 +1013,68 @@ def test_instructions_exact_max_no_overflow(self) -> None: assert '[k0]' in text assert '[k4]' in text + def test_instructions_use_summary(self) -> None: + store = DictMemoryStore() + store.put( + MemoryEntry(key='user', content='Alice is a 30-year-old data scientist', summary='Alice, data scientist') + ) + cap: Memory[None] = Memory(store=store) + text = cap.build_instructions(self._make_ctx()) + assert '[user] Alice, data scientist' in text + assert '30-year-old' not in text + + def test_instructions_falls_back_to_content_without_summary(self) -> None: + store = DictMemoryStore() + store.put(MemoryEntry(key='user', content='Alice is a 30-year-old data scientist')) + cap: Memory[None] = 
Memory(store=store) + text = cap.build_instructions(self._make_ctx()) + assert 'Alice is a 30-year-old data scientist' in text + + def test_instructions_byte_budget_truncates(self) -> None: + store = DictMemoryStore() + for i in range(10): + store.put(MemoryEntry(key=f'k{i}', content=f'content for entry {i}')) + # Tight budget: only ~2 short lines fit + cap: Memory[None] = Memory(store=store, byte_budget=80) + text = cap.build_instructions(self._make_ctx()) + assert '[k0]' in text + # k9 (last) almost certainly excluded by byte budget + assert '[k9]' not in text + assert 'more (use list_memories' in text + + def test_instructions_pinned_always_injected(self) -> None: + store = DictMemoryStore() + # 5 normal entries + 1 pinned, with max=2: pinned is always shown + for i in range(5): + store.put(MemoryEntry(key=f'normal{i}', content=f'v{i}')) + store.put(MemoryEntry(key='persona', content='I am a helpful assistant', read_only=True)) + cap: Memory[None] = Memory(store=store, max_instructions_memories=2) + text = cap.build_instructions(self._make_ctx()) + assert '[persona] I am a helpful assistant' in text + # Two normal entries + the pinned one = 3 displayed + normal_displayed = sum(1 for i in range(5) if f'[normal{i}]' in text) + assert normal_displayed == 2 + + def test_instructions_pinned_bypasses_byte_budget(self) -> None: + store = DictMemoryStore() + # Big pinned entry that would overflow a tight budget + store.put(MemoryEntry(key='persona', content='x' * 500, read_only=True)) + store.put(MemoryEntry(key='normal', content='small')) + cap: Memory[None] = Memory(store=store, byte_budget=50) + text = cap.build_instructions(self._make_ctx()) + assert '[persona]' in text + # Pinned takes the budget; normal gets dropped + assert '[normal]' not in text + + def test_instructions_pinned_listed_first(self) -> None: + store = DictMemoryStore() + store.put(MemoryEntry(key='zzz_normal', content='v')) + store.put(MemoryEntry(key='aaa_pinned', content='p', read_only=True)) + 
cap: Memory[None] = Memory(store=store) + text = cap.build_instructions(self._make_ctx()) + # Pinned line appears before normal line + assert text.index('[aaa_pinned]') < text.index('[zzz_normal]') + # --- MemoryStore protocol --- From f0ff1dadc7f941f3022153bda09d137f85848cb1 Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Thu, 7 May 2026 09:43:40 -0500 Subject: [PATCH 18/21] feat(memory): add tool_descriptions override on Memory Per-tool description override at toolset construction. Keys are tool names (save_memory, recall_memory, search_memories, list_memories, delete_memory); values replace the docstring used by the LLM. Useful for nudging agent behavior toward project-specific conventions (e.g., "Save aggressively, even small facts"). Borrowed from pydantic-deepagents AgentMemoryToolset. Tools without an override fall back to their docstring as before. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/pydantic_harness/memory.py | 16 +++++++++++----- tests/test_memory.py | 10 ++++++++++ 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/src/pydantic_harness/memory.py b/src/pydantic_harness/memory.py index ab588c7..27c4ddb 100644 --- a/src/pydantic_harness/memory.py +++ b/src/pydantic_harness/memory.py @@ -535,6 +535,11 @@ class Memory(AbstractCapability[AgentDepsT]): Pass any `Callable[[MemoryEntry], float]` for custom decay shapes. """ + tool_descriptions: dict[str, str] = field(default_factory=lambda: dict[str, str]()) + """Per-tool description overrides. Keys are tool names (`save_memory`, `recall_memory`, + `search_memories`, `list_memories`, `delete_memory`); values replace the docstring used + by the LLM. 
Useful for nudging the agent (e.g., "Save aggressively, even small facts").""" + @classmethod def get_serialization_name(cls) -> str | None: """Return the name used for spec serialization.""" @@ -738,12 +743,13 @@ def delete_memory(key: str) -> str: return f'Memory deleted: {key}' return f'No memory found for key: {key}' + descs = self.tool_descriptions return FunctionToolset( [ - Tool(save_memory, takes_ctx=False), - Tool(recall_memory, takes_ctx=False), - Tool(search_memories, takes_ctx=False), - Tool(list_memories, takes_ctx=False), - Tool(delete_memory, takes_ctx=False), + Tool(save_memory, takes_ctx=False, description=descs.get('save_memory')), + Tool(recall_memory, takes_ctx=False, description=descs.get('recall_memory')), + Tool(search_memories, takes_ctx=False, description=descs.get('search_memories')), + Tool(list_memories, takes_ctx=False, description=descs.get('list_memories')), + Tool(delete_memory, takes_ctx=False, description=descs.get('delete_memory')), ], ) diff --git a/tests/test_memory.py b/tests/test_memory.py index 3713869..789af8f 100644 --- a/tests/test_memory.py +++ b/tests/test_memory.py @@ -748,6 +748,16 @@ def test_toolset_has_expected_tools(self) -> None: tool_names = set(toolset.tools.keys()) assert tool_names == {'save_memory', 'recall_memory', 'search_memories', 'list_memories', 'delete_memory'} + def test_tool_descriptions_override(self) -> None: + custom = 'CUSTOM: Save aggressively. Tiny facts count.' 
+ cap: Memory[None] = Memory(tool_descriptions={'save_memory': custom}) + toolset = cap.get_toolset() + assert isinstance(toolset, FunctionToolset) + assert toolset.tools['save_memory'].description == custom + # Other tools fall back to docstring (not the override) + recall_desc = toolset.tools['recall_memory'].description + assert recall_desc is not None and 'Recall' in recall_desc + # --- Tool functions (via closure) --- From 23e9846bb9e2c4b6b396d7b4dd9cf1eabf6509d4 Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Thu, 7 May 2026 09:47:14 -0500 Subject: [PATCH 19/21] docs(memory): add Postgres reference backend at examples/memory/postgres_store.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reference implementation of MemoryStore against Postgres via psycopg. ~150 LOC. Schema: key TEXT PK, namespace TEXT[], data JSONB. Implements all six Protocol methods (get, put, delete, list_all, search, list_namespaces). Filtering happens in SQL (namespace prefix via array slicing, metadata equality via JSONB ops); keyword scoring runs in Python after the DB pre-filter, matching DictMemoryStore semantics. Documented as a starting point — production users will want connection pooling, schema migrations, and full-text or pgvector ranking. Also exempts examples/ from the ruff D (pydocstyle) ruleset, mirroring the existing tests/ exemption — example scripts are illustrative, not API surface. Co-Authored-By: Claude Opus 4.7 (1M context) --- examples/memory/postgres_store.py | 167 ++++++++++++++++++++++++++++++ pyproject.toml | 1 + 2 files changed, 168 insertions(+) create mode 100644 examples/memory/postgres_store.py diff --git a/examples/memory/postgres_store.py b/examples/memory/postgres_store.py new file mode 100644 index 0000000..acf4935 --- /dev/null +++ b/examples/memory/postgres_store.py @@ -0,0 +1,167 @@ +"""Postgres backend for the Memory capability — reference implementation. 
+ +Shows how to implement `MemoryStore` against Postgres using `psycopg`. This is +a starting point, not a production-ready backend: adapt to your deployment +(connection pooling, schema migrations, async, full-text search via tsvector). + +For semantic retrieval, swap the `search` implementation for one that runs +`SELECT ... ORDER BY embedding <=> %s` against a pgvector column. + +Schema: + CREATE TABLE memories ( + key TEXT PRIMARY KEY, + namespace TEXT[] NOT NULL DEFAULT ARRAY['global'], + data JSONB NOT NULL + ); + +Usage: + pip install 'psycopg[binary]' + + import psycopg + from pydantic_ai import Agent + from pydantic_harness.memory import Memory + from examples.memory.postgres_store import PostgresMemoryStore + + conn = psycopg.connect('postgresql://localhost/myapp') + agent = Agent('openai:gpt-4o', capabilities=[Memory(store=PostgresMemoryStore(conn))]) +""" + +from __future__ import annotations + +import json +import re +from typing import Any + +import psycopg + +from pydantic_harness.memory import MemoryEntry, RecencyScorer + +SCHEMA = """ +CREATE TABLE IF NOT EXISTS memories ( + key TEXT PRIMARY KEY, + namespace TEXT[] NOT NULL DEFAULT ARRAY['global'], + data JSONB NOT NULL +); +CREATE INDEX IF NOT EXISTS memories_namespace_idx ON memories USING GIN (namespace); +""" + + +class PostgresMemoryStore: + """`MemoryStore` backed by Postgres via psycopg. + + Implements the full `MemoryStore` Protocol: `get`, `put`, `delete`, + `list_all`, `search`, `list_namespaces`. Filtering happens in SQL + (namespace prefix via array slicing, metadata equality via JSONB ops); + keyword scoring runs in Python after the DB pre-filter. 
+ """ + + def __init__(self, conn: psycopg.Connection[Any]) -> None: + self._conn = conn + with self._conn.cursor() as cur: + cur.execute(SCHEMA) + self._conn.commit() + + def get(self, key: str) -> MemoryEntry | None: + with self._conn.cursor() as cur: + cur.execute('SELECT data FROM memories WHERE key = %s', (key,)) + row = cur.fetchone() + if row is None: + return None + entry = MemoryEntry.from_dict(row[0]) + return None if entry.is_expired() else entry + + def put(self, entry: MemoryEntry) -> None: + with self._conn.cursor() as cur: + cur.execute( + 'INSERT INTO memories (key, namespace, data) VALUES (%s, %s, %s) ' + 'ON CONFLICT (key) DO UPDATE SET namespace = EXCLUDED.namespace, data = EXCLUDED.data', + (entry.key, list(entry.namespace), json.dumps(entry.to_dict())), + ) + self._conn.commit() + + def delete(self, key: str) -> bool: + with self._conn.cursor() as cur: + cur.execute('DELETE FROM memories WHERE key = %s', (key,)) + deleted = cur.rowcount > 0 + self._conn.commit() + return deleted + + def list_all( + self, + *, + namespace: tuple[str, ...] | None = None, + filter: dict[str, object] | None = None, + ) -> list[MemoryEntry]: + sql = 'SELECT data FROM memories WHERE TRUE' + params: list[Any] = [] + if namespace is not None: + # Prefix match: entry.namespace[1:N] = supplied namespace tuple + sql += ' AND namespace[1:%s] = %s' + params.extend([len(namespace), list(namespace)]) + if filter is not None: + for k, v in filter.items(): + sql += " AND (data -> 'metadata' ->> %s) = %s" + params.extend([k, str(v)]) + with self._conn.cursor() as cur: + cur.execute(sql, params) + rows = cur.fetchall() + entries = [MemoryEntry.from_dict(r[0]) for r in rows] + return [e for e in entries if not e.is_expired()] + + def search( + self, + query: str, + *, + namespace: tuple[str, ...] 
| None = None, + filter: dict[str, object] | None = None, + recency_scorer: RecencyScorer | None = None, + ) -> list[MemoryEntry]: + # SQL-side filter, Python-side keyword scoring matching DictMemoryStore semantics. + # Production: replace with full-text or pgvector ranking. + words = query.lower().split() + if not words: + return [] + candidates = self.list_all(namespace=namespace, filter=filter) + scored: list[tuple[float, MemoryEntry]] = [] + for entry in candidates: + base = 0 + for word in words: + pattern = re.compile(rf'(? list[tuple[str, ...]]: + with self._conn.cursor() as cur: + cur.execute('SELECT DISTINCT namespace FROM memories') + rows = cur.fetchall() + seen: set[tuple[str, ...]] = set() + for row in rows: + ns: tuple[str, ...] = tuple(row[0]) + if max_depth is not None: + ns = ns[:max_depth] + if prefix is not None and (len(ns) < len(prefix) or ns[: len(prefix)] != prefix): + continue + if suffix is not None and (len(ns) < len(suffix) or ns[-len(suffix) :] != suffix): + continue + seen.add(ns) + return sorted(seen) diff --git a/pyproject.toml b/pyproject.toml index 8c5e481..16560bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,6 +63,7 @@ extend-select = ['Q', 'RUF100', 'C90', 'UP', 'I', 'D', 'TID251'] [tool.ruff.lint.per-file-ignores] 'tests/**/*.py' = ['D'] +'examples/**/*.py' = ['D'] [tool.ruff.lint.flake8-quotes] inline-quotes = 'single' From 8f58a027b460997a91a851491f157bebdfe39a14 Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Thu, 7 May 2026 09:49:30 -0500 Subject: [PATCH 20/21] docs(memory): add capability page, multi-agent example, refresh PLAN.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - docs/capabilities/memory.md — full user-facing capability docs covering quick start, MemoryEntry shape, built-in vs custom backends, namespaces, multi-agent shared store, search/recency, prompt-cache trade-off, tool description overrides, custom 
backends, and the known followups list (semantic retrieval, dedup hook). - Memory class docstring gains a multi-agent shared-store example so the pattern is discoverable from autocomplete without reading docs. - PLAN.md fully refreshed: reflects all current fields, the namespace tuple, recency scoring, byte budget, read_only pinning, the Postgres reference, and the Future Work followups. - docs/capabilities/index.md links the new memory page. Co-Authored-By: Claude Opus 4.7 (1M context) --- PLAN.md | 110 +++++++++++++++++------ docs/capabilities/index.md | 2 +- docs/capabilities/memory.md | 155 +++++++++++++++++++++++++++++++++ src/pydantic_harness/memory.py | 13 +++ 4 files changed, 250 insertions(+), 30 deletions(-) create mode 100644 docs/capabilities/memory.md diff --git a/PLAN.md b/PLAN.md index abb9074..8ca5a95 100644 --- a/PLAN.md +++ b/PLAN.md @@ -2,53 +2,105 @@ ## Summary -Implements a `Memory` capability (`AbstractCapability` subclass) that provides persistent key-value memory across agent sessions, referencing issues #30 and #31. +Implements a `Memory` capability (`AbstractCapability` subclass) providing +persistent key-value memory across agent sessions. References issues #30 and #31. + +User-facing docs: [`docs/capabilities/memory.md`](docs/capabilities/memory.md). 
## Design ### Architecture - **`Memory`** dataclass extends `AbstractCapability[AgentDepsT]` - - `get_instructions()` returns a dynamic callable that injects stored memories into the system prompt at run start - - `get_toolset()` returns a `FunctionToolset` with five tools: `save_memory`, `recall_memory`, `search_memories`, `list_memories`, `delete_memory` - - Tool functions use closures over `self.store` (no dependency on agent `deps`) + - `get_instructions()` returns a dynamic callable injecting stored memories + into the system prompt at run start + - `get_toolset()` returns a `FunctionToolset` with five tools: `save_memory`, + `recall_memory`, `search_memories`, `list_memories`, `delete_memory` + - Per-tool description overrides via `tool_descriptions: dict[str, str]` + - Tool functions use closures over `self.store` and `self.recency_scorer` + (no dependency on agent `deps`) ### Storage -- **`MemoryStore`** protocol: pluggable backend with `get`, `put`, `delete`, `list_all`, `search` -- **`DictMemoryStore`**: dict-based, ephemeral, for testing (default) -- **`FileMemoryStore`**: JSON file on disk, reads on init, writes on every mutation - -### Memory Model - -- **`MemoryEntry`** dataclass: `key`, `content`, `tags` (list[str]), `scope`, `expires_at`, `created_at`, `updated_at` -- **`MemoryEntryDict`** TypedDict for serialization -- Word-boundary search with relevance scoring (case-insensitive) across key, content, and tags -- Scoping/namespaces via `scope` field with filtering on search/list -- TTL/expiration via `expires_at` with `is_expired()` auto-filtering -- Dedup warning on save when keys are similar (Levenshtein distance <= 2) - -### Spec Serialization - -- `Memory.get_serialization_name()` returns `"Memory"` -- `Memory.from_spec(backend="file", path="...")` creates a `FileMemoryStore`-backed instance +- **`MemoryStore`** Protocol: pluggable backend with six methods — `get`, `put`, + `delete`, `list_all`, `search`, `list_namespaces` +- 
**`DictMemoryStore`**: dict-based, ephemeral, for tests/scratch (default) +- **`FileMemoryStore(path)`**: JSON file on disk, reads on init, writes on every + mutation; drops expired entries on save +- Both extend `_BaseDictStore` for shared logic +- Custom backends: implement the Protocol. See + [`examples/memory/postgres_store.py`](examples/memory/postgres_store.py) for + a Postgres reference. + +### `MemoryEntry` + +Required: `key`, `content`. Optional fields: + +- `tags: list[str]` — LLM-set categorisation +- `namespace: tuple[str, ...]` — hierarchical namespace, prefix-matched +- `expires_at: str | None` — ISO 8601 wall-clock expiry (opt-in TTL) +- `created_at`, `updated_at: str` — ISO 8601 timestamps +- `summary: str | None` — preferred over `content` for prompt injection +- `metadata: dict[str, object]` — structured attributes; filterable via search +- `read_only: bool` — pin against agent modification (always injected) +- `char_limit: int | None` — hard cap on `content` length, enforced at construction +- `importance: float | None` — search-score booster + +`MemoryEntryDict` TypedDict covers the JSON serialisation form. 
+ +### Search & retrieval + +- `search` score = keyword-match count + `entry.importance` (if set) + + `recency_scorer(entry)` (if provided) +- Word-boundary regex matching across `key`, `content`, `tags` (case-insensitive) +- `_` and `-` count as word separators +- Default recency scorer: `exponential_decay(half_life_days=30, weight=0.5)` +- `search`/`list_all` accept `namespace` (prefix match) and `filter` (metadata + equality) kwargs; `search` additionally accepts `recency_scorer` + +### Instructions injection + +- `read_only=True` entries always inject (bypass count cap and byte budget) +- Non-pinned entries respect `max_instructions_memories` (default 20) and + `byte_budget: int | None` (UTF-8 byte cap) +- `entry.summary` is preferred over `entry.content` to save tokens +- Pinned entries are listed first +- Disabled entirely via `inject_memories_in_instructions=False` (prompt-cache + mitigation for write-heavy workloads) + +### Spec serialisation + +- `Memory.get_serialization_name()` → `"Memory"` +- `Memory.from_spec(backend='memory'|'file', path=..., ...)` for declarative config ## Configuration | Field | Default | Description | -|-------|---------|-------------| +|---|---|---| | `store` | `DictMemoryStore()` | Storage backend | | `inject_memories_in_instructions` | `True` | Include memories in system prompt | -| `max_instructions_memories` | `20` | Cap on memories injected into prompt | +| `max_instructions_memories` | `20` | Cap on non-pinned memories injected | +| `byte_budget` | `None` | Optional UTF-8 byte cap on injection block | +| `recency_scorer` | `exponential_decay(half_life_days=30, weight=0.5)` | Or `None` to disable | +| `tool_descriptions` | `{}` | Per-tool description overrides | ## Files -- `src/pydantic_harness/memory.py` - Capability, stores, entry model -- `src/pydantic_harness/__init__.py` - Re-exports -- `tests/test_memory.py` - 113 tests covering all code paths +- `src/pydantic_harness/memory.py` — capability, stores, entry model, 
recency helpers +- `src/pydantic_harness/__init__.py` — re-exports +- `tests/test_memory.py` — 150 tests covering all code paths +- `examples/memory/*.py` — three runnable examples plus the Postgres reference +- `docs/capabilities/memory.md` — user-facing docs ## Future Work -- Semantic/vector search backend (e.g. embedding-based `MemoryStore`) -- Session-scoped memory isolation via `for_run()` -- SQLite / Redis backends for production persistence +- **Semantic retrieval** — `SemanticMemoryStore` Protocol extension and an + `EmbeddingStore` reference (numpy/cosine, or pgvector). Deferred until a + concrete backend drives the API design — premature design tends to lock in + the wrong shape. +- **Tool-history dedup hook** — suppress next-turn injection of memories the + LLM just saved (already visible in tool history). Has subtle semantics around + in-run updates; deferred until usage telemetry shows the magnitude of the + token-waste problem. +- **Deferred capability loading** (PR #5230 in pydantic-ai) — once that lands, + declare `id`/`description` on `Memory` to opt into deferred loading. diff --git a/docs/capabilities/index.md b/docs/capabilities/index.md index b12b547..dcef380 100644 --- a/docs/capabilities/index.md +++ b/docs/capabilities/index.md @@ -11,7 +11,7 @@ Pydantic AI agent via the `capabilities` parameter. 
| FileSystem | Read, write, and navigate the local filesystem | | Guardrails | Validate inputs/outputs and enforce cost and tool constraints | | KnowsCurrentTime | Inject the current date and time into the system prompt | -| Memory | Persistent key-value memory across agent sessions | +| [Memory](memory.md) | Persistent key-value memory across agent sessions | | Planning | Break complex tasks into plans before execution | | RepoContextInjection | Inject repository structure and context into the system prompt | | SecretMasking | Detect and redact secrets in agent inputs and outputs | diff --git a/docs/capabilities/memory.md b/docs/capabilities/memory.md new file mode 100644 index 0000000..9f05f3c --- /dev/null +++ b/docs/capabilities/memory.md @@ -0,0 +1,155 @@ +# Memory + +Persistent key-value memory across agent sessions. Provides five tools the LLM +can call (`save_memory`, `recall_memory`, `search_memories`, `list_memories`, +`delete_memory`) and injects currently-stored memories into the system prompt +each run. + +## Quick start + +```python +from pydantic_ai import Agent +from pydantic_harness import Memory + +agent = Agent('openai:gpt-4o', capabilities=[Memory()]) +``` + +By default `Memory()` uses an in-process `DictMemoryStore` — entries live only +for the lifetime of the Python process. Use `FileMemoryStore` for single-user +single-process persistence, or implement `MemoryStore` for anything else. + +## Built-in backends + +| Backend | Persistence | Concurrency | Use case | +|---|---|---|---| +| `DictMemoryStore` | None (in-process) | Single-thread | Tests, scratch agents | +| `FileMemoryStore(path)` | JSON file on disk | Single-process | Single-user CLI agents | + +For Postgres, Redis, vector DBs, etc. — implement the `MemoryStore` Protocol. +See [`examples/memory/postgres_store.py`](https://github.com/pydantic/pydantic-harness/blob/main/examples/memory/postgres_store.py) +for a reference implementation. 
+ +## `MemoryEntry` fields + +| Field | Type | Default | Notes | +|---|---|---|---| +| `key` | `str` | required | Unique identifier | +| `content` | `str` | required | The fact itself | +| `tags` | `list[str]` | `[]` | LLM-set categorisation | +| `namespace` | `tuple[str, ...]` | `('global',)` | Hierarchical namespace; prefix-matched in queries | +| `expires_at` | `str \| None` | `None` | ISO 8601 wall-clock expiry; opt-in TTL | +| `created_at`, `updated_at` | `str` | now | ISO 8601 timestamps | +| `summary` | `str \| None` | `None` | Short version preferred over `content` for prompt injection | +| `metadata` | `dict[str, object]` | `{}` | Structured attributes; filterable via `search(filter=...)` | +| `read_only` | `bool` | `False` | If True, agent's tools refuse to modify | +| `char_limit` | `int \| None` | `None` | Optional hard cap on `content` length (raises at construction) | +| `importance` | `float \| None` | `None` | Search-score booster | + +## Namespaces + +Hierarchical, tuple-based. Filters in `list_all`/`search` use prefix matching: + +```python +store.put(MemoryEntry(key='a', content='...', namespace=('users', 'alice'))) +store.put(MemoryEntry(key='b', content='...', namespace=('users', 'bob'))) +store.put(MemoryEntry(key='c', content='...', namespace=('agents', 'planner'))) + +store.list_all(namespace=('users',)) # → [a, b] +store.list_namespaces(prefix=('users',)) # → [('users', 'alice'), ('users', 'bob')] +``` + +## Multi-agent shared memory + +One store, two agents, separate namespaces: + +```python +from pydantic_ai import Agent +from pydantic_harness import FileMemoryStore, Memory + +shared = FileMemoryStore('/var/lib/myapp/memory.json') + +planner = Agent('openai:gpt-4o', capabilities=[ + Memory(store=shared, byte_budget=2000), +]) +worker = Agent('openai:gpt-4o-mini', capabilities=[ + Memory(store=shared, byte_budget=500), +]) +``` + +Entries written by either agent are visible to both. Use `namespace=('agents', +'planner')` etc. 
on saves to keep their workspaces separate while still sharing +common facts in `('global',)`. + +## Search + +Word-boundary regex on key, content, and tags. Final score: + +``` +keyword_match_count + (entry.importance or 0) + (recency_scorer(entry) or 0) +``` + +Recency boost is enabled by default — `Memory` ships with +`exponential_decay(half_life_days=30, weight=0.5)`. Override with any callable: + +```python +from pydantic_harness import Memory, exponential_decay + +# Tighter half-life for fast-moving information +Memory(recency_scorer=exponential_decay(half_life_days=7)) + +# Custom: boost only entries with the 'pinned' tag +Memory(recency_scorer=lambda e: 1.0 if 'pinned' in e.tags else 0.0) + +# Disable recency entirely +Memory(recency_scorer=None) +``` + +## Prompt-cache trade-off + +Every save/delete changes the injected memories block in the system prompt, +invalidating the prompt-cache prefix. Read-heavy workloads keep the cache; +write-heavy workloads thrash. + +Mitigation: `Memory(inject_memories_in_instructions=False)` skips the +injection. The LLM reads memories only via explicit `list_memories` / +`search_memories` / `recall_memory` calls — system prompt prefix stays stable +across writes. + +For partial mitigation, set `byte_budget` to cap the injected block size. + +## Tool description overrides + +```python +Memory(tool_descriptions={ + 'save_memory': 'Save anything the user mentions about themselves, ' + 'even tiny details. Tag with "user_pref".', +}) +``` + +## Custom backends + +Implement the `MemoryStore` Protocol — six methods, all positional/kwarg-only: + +```python +from pydantic_harness import MemoryEntry, MemoryStore, RecencyScorer + +class MyStore: + def get(self, key: str) -> MemoryEntry | None: ... + def put(self, entry: MemoryEntry) -> None: ... + def delete(self, key: str) -> bool: ... + def list_all(self, *, namespace=None, filter=None): ... + def search(self, query, *, namespace=None, filter=None, recency_scorer=None): ... 
+ def list_namespaces(self, *, prefix=None, suffix=None, max_depth=None): ... +``` + +Pass an instance as `Memory(store=MyStore())`. See the Postgres example for a +working reference. + +## Known followups + +- **Semantic retrieval**: `SemanticMemoryStore` Protocol extension and an + `EmbeddingStore` reference impl. Deferred until a concrete backend (Qdrant / + pgvector / LanceDB) drives the API design. +- **Tool-history dedup**: suppress next-turn injection of memories the LLM + just saved (already in tool history). Deferred — has subtle semantics around + updates that need real-world telemetry to design correctly. diff --git a/src/pydantic_harness/memory.py b/src/pydantic_harness/memory.py index 27c4ddb..776e73c 100644 --- a/src/pydantic_harness/memory.py +++ b/src/pydantic_harness/memory.py @@ -504,6 +504,19 @@ class Memory(AbstractCapability[AgentDepsT]): agent = Agent('openai:gpt-4o', capabilities=[Memory(store=DictMemoryStore())]) ``` + + Multi-agent shared store: + ```python {test="skip" lint="skip"} + from pydantic_ai import Agent + from pydantic_harness.memory import FileMemoryStore, Memory + + shared = FileMemoryStore('/var/lib/myapp/memory.json') + planner = Agent('openai:gpt-4o', capabilities=[Memory(store=shared, byte_budget=2000)]) + worker = Agent('openai:gpt-4o-mini', capabilities=[Memory(store=shared, byte_budget=500)]) + ``` + Both agents see the same entries; use distinct `namespace` tuples on + saves to keep their workspaces separate (e.g., `('agents', 'planner')` + vs `('agents', 'worker')`).
""" store: MemoryStore = field(default_factory=DictMemoryStore) From 76f0a4edbfb1753c43918b1a76f3242395792b7c Mon Sep 17 00:00:00 2001 From: David Sanchez <64162682+dsfaccini@users.noreply.github.com> Date: Thu, 7 May 2026 10:19:23 -0500 Subject: [PATCH 21/21] feat(memory): dedup against tool history when injecting memories build_instructions now skips re-injecting an entry when the LLM has already seen its current value via a save_memory call in this run's tool history. Default ON (dedup_recent_saves=True); disable per-Memory instance via the flag. The check is content-aware: scans ctx.messages for ToolCallParts named 'save_memory', tracks the last (key, content) saved per key, and only suppresses an entry when entry.content == last_saved_content. If something updated the entry externally (another agent, manual store mutation, etc.), the saved content no longer matches the store, so we inject to let the LLM see the current value. read_only entries always inject regardless of dedup. This is the previously-deferred "tool-history dedup" followup; the content-equality safeguard resolves the in-run-update gotcha. Co-Authored-By: Claude Opus 4.7 (1M context) --- PLAN.md | 10 +++--- docs/capabilities/memory.md | 12 +++++-- src/pydantic_harness/memory.py | 49 ++++++++++++++++++++++++- tests/test_memory.py | 66 ++++++++++++++++++++++++++++++++++ 4 files changed, 128 insertions(+), 9 deletions(-) diff --git a/PLAN.md b/PLAN.md index 8ca5a95..3ffd0f7 100644 --- a/PLAN.md +++ b/PLAN.md @@ -60,10 +60,13 @@ Required: `key`, `content`. 
Optional fields: ### Instructions injection -- `read_only=True` entries always inject (bypass count cap and byte budget) +- `read_only=True` entries always inject (bypass count cap, byte budget, and dedup) - Non-pinned entries respect `max_instructions_memories` (default 20) and `byte_budget: int | None` (UTF-8 byte cap) - `entry.summary` is preferred over `entry.content` to save tokens +- `dedup_recent_saves: bool = True` suppresses entries the LLM just saved in + this run's tool history, when the saved content still matches the store + entry (content-aware: if external state diverged, inject the current value) - Pinned entries are listed first - Disabled entirely via `inject_memories_in_instructions=False` (prompt-cache mitigation for write-heavy workloads) @@ -83,6 +86,7 @@ Required: `key`, `content`. Optional fields: | `byte_budget` | `None` | Optional UTF-8 byte cap on injection block | | `recency_scorer` | `exponential_decay(half_life_days=30, weight=0.5)` | Or `None` to disable | | `tool_descriptions` | `{}` | Per-tool description overrides | +| `dedup_recent_saves` | `True` | Suppress injection of entries just saved in this run | ## Files @@ -98,9 +102,5 @@ Required: `key`, `content`. Optional fields: `EmbeddingStore` reference (numpy/cosine, or pgvector). Deferred until a concrete backend drives the API design — premature design tends to lock in the wrong shape. -- **Tool-history dedup hook** — suppress next-turn injection of memories the - LLM just saved (already visible in tool history). Has subtle semantics around - in-run updates; deferred until usage telemetry shows the magnitude of the - token-waste problem. - **Deferred capability loading** (PR #5230 in pydantic-ai) — once that lands, declare `id`/`description` on `Memory` to opt into deferred loading. 
diff --git a/docs/capabilities/memory.md b/docs/capabilities/memory.md index 9f05f3c..1016e04 100644 --- a/docs/capabilities/memory.md +++ b/docs/capabilities/memory.md @@ -117,6 +117,15 @@ across writes. For partial mitigation, set `byte_budget` to cap the injected block size. +## Dedup against tool history + +By default (`dedup_recent_saves=True`), an entry is not injected when the +LLM has already seen its current value via a `save_memory` call in this run's +tool history. The check is content-aware: if the entry was updated externally +(e.g., by another agent), the saved content no longer matches the store, and +the entry is injected so the LLM sees the current value. `read_only=True` +entries are never suppressed. Disable with `dedup_recent_saves=False`. + ## Tool description overrides ```python @@ -150,6 +159,3 @@ working reference. - **Semantic retrieval**: `SemanticMemoryStore` Protocol extension and an `EmbeddingStore` reference impl. Deferred until a concrete backend (Qdrant / pgvector / LanceDB) drives the API design. -- **Tool-history dedup**: suppress next-turn injection of memories the LLM - just saved (already in tool history). Deferred — has subtle semantics around - updates that need real-world telemetry to design correctly.
diff --git a/src/pydantic_harness/memory.py b/src/pydantic_harness/memory.py index 776e73c..3d4fd11 100644 --- a/src/pydantic_harness/memory.py +++ b/src/pydantic_harness/memory.py @@ -18,6 +18,7 @@ from pydantic_ai._instructions import AgentInstructions from pydantic_ai.capabilities.abstract import AbstractCapability +from pydantic_ai.messages import ModelMessage, ModelResponse, ToolCallPart from pydantic_ai.tools import AgentDepsT, RunContext, Tool from pydantic_ai.toolsets import AgentToolset from pydantic_ai.toolsets.function import FunctionToolset @@ -209,6 +210,32 @@ def scorer(entry: MemoryEntry) -> float: return scorer +def _saves_in_history(messages: list[ModelMessage]) -> dict[str, str]: + """Scan tool history for `save_memory` calls; return `{key: last saved content}`. + + Used by `Memory.build_instructions` to suppress re-injecting entries the LLM + just saved — when the saved content still matches the current store entry, + the LLM has already seen the value via the tool call/result in history, + so re-injecting it wastes tokens. If a key was saved multiple times, the + most recent save wins. + """ + last: dict[str, str] = {} + for msg in messages: + if not isinstance(msg, ModelResponse): + continue + for part in msg.parts: + if not isinstance(part, ToolCallPart): + continue + if part.tool_name != 'save_memory': + continue + args = part.args_as_dict() + key = args.get('key') + content = args.get('content') + if isinstance(key, str) and isinstance(content, str): + last[key] = content + return last + + def _matches_filter(entry: MemoryEntry, filter_: dict[str, object]) -> bool: """Return True if all filter keys match `entry.metadata` values exactly.""" for key, value in filter_.items(): @@ -553,6 +580,16 @@ class Memory(AbstractCapability[AgentDepsT]): `search_memories`, `list_memories`, `delete_memory`); values replace the docstring used by the LLM. 
Useful for nudging the agent (e.g., "Save aggressively, even small facts").""" + dedup_recent_saves: bool = True + """When True, suppress injection of entries that match a `save_memory` call + in the current run's tool history (the LLM has already seen the value). + + The check is content-aware: if the store entry's `content` differs from the + most recent saved content (e.g., another process updated the entry), the + entry is injected so the LLM sees the current value. `read_only=True` + entries are never suppressed. + """ + @classmethod def get_serialization_name(cls) -> str | None: """Return the name used for spec serialization.""" @@ -592,9 +629,12 @@ def build_instructions(self, ctx: RunContext[AgentDepsT]) -> str: """Build dynamic instructions that include currently stored memories. Selection rules: - - `read_only=True` entries always inject (bypass both count cap and byte budget). + - `read_only=True` entries always inject (bypass count cap, byte budget, and dedup). - Non-pinned entries respect `max_instructions_memories` and `byte_budget`. - When `entry.summary` is set, it's preferred over `entry.content` to save tokens. + - When `dedup_recent_saves` is True, entries whose current content matches + the most recent `save_memory` call in this run's tool history are suppressed + (the LLM has already seen the value via the tool call). - Pinned entries are listed first. 
""" parts: list[str] = [ @@ -610,6 +650,8 @@ def build_instructions(self, ctx: RunContext[AgentDepsT]) -> str: parts.append('\nCurrently stored memories:') + recent_saves: dict[str, str] = _saves_in_history(ctx.messages) if self.dedup_recent_saves else {} + # Pinned first, then the rest in store order ordered = sorted(entries, key=lambda e: not e.read_only) @@ -617,6 +659,11 @@ def build_instructions(self, ctx: RunContext[AgentDepsT]) -> str: used_bytes = 0 consumed_non_pinned = 0 for entry in ordered: + # read_only entries bypass dedup, count cap, and byte budget + if not entry.read_only: + saved_content = recent_saves.get(entry.key) + if saved_content is not None and saved_content == entry.content: + continue line = f'- {format_entry(entry, prefer_summary=True)}' line_bytes = len(line.encode('utf-8')) if entry.read_only: diff --git a/tests/test_memory.py b/tests/test_memory.py index 789af8f..4a59074 100644 --- a/tests/test_memory.py +++ b/tests/test_memory.py @@ -1085,6 +1085,72 @@ def test_instructions_pinned_listed_first(self) -> None: # Pinned line appears before normal line assert text.index('[aaa_pinned]') < text.index('[zzz_normal]') + @staticmethod + def _ctx_with_save(key: str, content: str) -> RunContext[None]: + from unittest.mock import MagicMock + + from pydantic_ai.messages import ModelMessage, ModelResponse, ToolCallPart + + msgs: list[ModelMessage] = [ + ModelResponse(parts=[ToolCallPart(tool_name='save_memory', args={'key': key, 'content': content})]), + ] + return RunContext(deps=None, model=MagicMock(), usage=RunUsage(), messages=msgs) + + def test_instructions_dedup_suppresses_recently_saved(self) -> None: + store = DictMemoryStore() + store.put(MemoryEntry(key='user_name', content='Alice')) + cap: Memory[None] = Memory(store=store) + text = cap.build_instructions(self._ctx_with_save('user_name', 'Alice')) + # LLM has already seen 'Alice' via the save_memory call in tool history + assert '[user_name]' not in text + + def 
test_instructions_dedup_injects_when_content_diverges(self) -> None: + # Saved 'Alice' but store now has 'Alice (UPDATED)' (e.g., another agent updated it) + store = DictMemoryStore() + store.put(MemoryEntry(key='user_name', content='Alice (UPDATED)')) + cap: Memory[None] = Memory(store=store) + text = cap.build_instructions(self._ctx_with_save('user_name', 'Alice')) + # Inject because current content differs from what the LLM saw saved + assert '[user_name] Alice (UPDATED)' in text + + def test_instructions_dedup_disabled_via_flag(self) -> None: + store = DictMemoryStore() + store.put(MemoryEntry(key='user_name', content='Alice')) + cap: Memory[None] = Memory(store=store, dedup_recent_saves=False) + text = cap.build_instructions(self._ctx_with_save('user_name', 'Alice')) + # Always inject when dedup is off + assert '[user_name] Alice' in text + + def test_instructions_dedup_does_not_affect_pinned(self) -> None: + store = DictMemoryStore() + store.put(MemoryEntry(key='persona', content='I am helpful', read_only=True)) + cap: Memory[None] = Memory(store=store) + text = cap.build_instructions(self._ctx_with_save('persona', 'I am helpful')) + # Pinned entries always inject regardless of dedup + assert '[persona] I am helpful' in text + + def test_instructions_dedup_uses_most_recent_save(self) -> None: + # Two saves of the same key; only the most recent one's content matters for dedup. 
+ from unittest.mock import MagicMock + + from pydantic_ai.messages import ModelMessage, ModelResponse, ToolCallPart + + store = DictMemoryStore() + store.put(MemoryEntry(key='user_pref', content='loves green')) + msgs: list[ModelMessage] = [ + ModelResponse( + parts=[ToolCallPart(tool_name='save_memory', args={'key': 'user_pref', 'content': 'loves blue'})] + ), + ModelResponse( + parts=[ToolCallPart(tool_name='save_memory', args={'key': 'user_pref', 'content': 'loves green'})] + ), + ] + ctx: RunContext[None] = RunContext(deps=None, model=MagicMock(), usage=RunUsage(), messages=msgs) + cap: Memory[None] = Memory(store=store) + text = cap.build_instructions(ctx) + # Most recent save matches current content → suppress + assert '[user_pref]' not in text + # --- MemoryStore protocol ---