diff --git a/.agents/settings.local.json b/.agents/settings.local.json deleted file mode 100644 index 8b311a3..0000000 --- a/.agents/settings.local.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "permissions": { - "allow": [] - } -} diff --git a/PLAN.md b/PLAN.md new file mode 100644 index 0000000..3ffd0f7 --- /dev/null +++ b/PLAN.md @@ -0,0 +1,106 @@ +# Memory Capability + +## Summary + +Implements a `Memory` capability (`AbstractCapability` subclass) providing +persistent key-value memory across agent sessions. References issues #30 and #31. + +User-facing docs: [`docs/capabilities/memory.md`](docs/capabilities/memory.md). + +## Design + +### Architecture + +- **`Memory`** dataclass extends `AbstractCapability[AgentDepsT]` + - `get_instructions()` returns a dynamic callable injecting stored memories + into the system prompt at run start + - `get_toolset()` returns a `FunctionToolset` with five tools: `save_memory`, + `recall_memory`, `search_memories`, `list_memories`, `delete_memory` + - Per-tool description overrides via `tool_descriptions: dict[str, str]` + - Tool functions use closures over `self.store` and `self.recency_scorer` + (no dependency on agent `deps`) + +### Storage + +- **`MemoryStore`** Protocol: pluggable backend with six methods — `get`, `put`, + `delete`, `list_all`, `search`, `list_namespaces` +- **`DictMemoryStore`**: dict-based, ephemeral, for tests/scratch (default) +- **`FileMemoryStore(path)`**: JSON file on disk, reads on init, writes on every + mutation; drops expired entries on save +- Both extend `_BaseDictStore` for shared logic +- Custom backends: implement the Protocol. See + [`examples/memory/postgres_store.py`](examples/memory/postgres_store.py) for + a Postgres reference. + +### `MemoryEntry` + +Required: `key`, `content`. 
Optional fields: + +- `tags: list[str]` — LLM-set categorisation +- `namespace: tuple[str, ...]` — hierarchical namespace, prefix-matched +- `expires_at: str | None` — ISO 8601 wall-clock expiry (opt-in TTL) +- `created_at`, `updated_at: str` — ISO 8601 timestamps +- `summary: str | None` — preferred over `content` for prompt injection +- `metadata: dict[str, object]` — structured attributes; filterable via search +- `read_only: bool` — pin against agent modification (always injected) +- `char_limit: int | None` — hard cap on `content` length, enforced at construction +- `importance: float | None` — search-score booster + +`MemoryEntryDict` TypedDict covers the JSON serialisation form. + +### Search & retrieval + +- `search` score = keyword-match count + `entry.importance` (if set) + + `recency_scorer(entry)` (if provided) +- Word-boundary regex matching across `key`, `content`, `tags` (case-insensitive) +- `_` and `-` count as word separators +- Default recency scorer: `exponential_decay(half_life_days=30, weight=0.5)` +- `search`/`list_all` accept `namespace` (prefix match) and `filter` (metadata + equality) kwargs; `search` additionally accepts `recency_scorer` + +### Instructions injection + +- `read_only=True` entries always inject (bypass count cap, byte budget, and dedup) +- Non-pinned entries respect `max_instructions_memories` (default 20) and + `byte_budget: int | None` (UTF-8 byte cap) +- `entry.summary` is preferred over `entry.content` to save tokens +- `dedup_recent_saves: bool = True` suppresses entries the LLM just saved in + this run's tool history, when the saved content still matches the store + entry (content-aware: if external state diverged, inject the current value) +- Pinned entries are listed first +- Disabled entirely via `inject_memories_in_instructions=False` (prompt-cache + mitigation for write-heavy workloads) + +### Spec serialisation + +- `Memory.get_serialization_name()` → `"Memory"` +- `Memory.from_spec(backend='memory'|'file', 
path=..., ...)` for declarative config + +## Configuration + +| Field | Default | Description | +|---|---|---| +| `store` | `DictMemoryStore()` | Storage backend | +| `inject_memories_in_instructions` | `True` | Include memories in system prompt | +| `max_instructions_memories` | `20` | Cap on non-pinned memories injected | +| `byte_budget` | `None` | Optional UTF-8 byte cap on injection block | +| `recency_scorer` | `exponential_decay(half_life_days=30, weight=0.5)` | Or `None` to disable | +| `tool_descriptions` | `{}` | Per-tool description overrides | +| `dedup_recent_saves` | `True` | Suppress injection of entries just saved in this run | + +## Files + +- `pydantic_ai_harness/memory/_capability.py` — capability, stores, entry model, recency helpers +- `pydantic_ai_harness/__init__.py` — re-exports +- `tests/test_memory.py` — 150 tests covering all code paths +- `examples/memory/*.py` — three runnable examples plus the Postgres reference +- `docs/capabilities/memory.md` — user-facing docs + +## Future Work + +- **Semantic retrieval** — `SemanticMemoryStore` Protocol extension and an + `EmbeddingStore` reference (numpy/cosine, or pgvector). Deferred until a + concrete backend drives the API design — premature design tends to lock in + the wrong shape. +- **Deferred capability loading** (PR #5230 in pydantic-ai) — once that lands, + declare `id`/`description` on `Memory` to opt into deferred loading. diff --git a/examples/memory/coding_assistant.py b/examples/memory/coding_assistant.py new file mode 100644 index 0000000..15a78dd --- /dev/null +++ b/examples/memory/coding_assistant.py @@ -0,0 +1,101 @@ +"""Self-Improving Coding Assistant — procedural memory via instructions injection. + +Demonstrates: instructions injection as self-modifying prompt, scoping, search, delete. 
+""" + +from __future__ import annotations + +import sys + +import logfire +from pydantic_ai import Agent + +from pydantic_ai_harness.memory import DictMemoryStore, Memory + +logfire.configure(send_to_logfire='if-token-present') +logfire.instrument_openai() # pyright: ignore[reportUnknownMemberType] + + +def main() -> None: + """Run the coding assistant example.""" + store = DictMemoryStore() + memory = Memory(store=store, max_instructions_memories=10) + + agent = Agent( + 'openai:gpt-4o-mini', + capabilities=[memory], + system_prompt=( + 'You are a coding assistant that learns from user corrections. ' + 'When the user gives you a coding rule or correction, save it as a memory ' + 'with namespace ["rules"] and tags like ["python", "style"] or ["typescript", "testing"]. ' + 'Use descriptive keys like "rule_python_fstrings" or "rule_ts_const". ' + 'When asked to write code, search your memories for relevant rules first.' + ), + ) + + # --- Teach rules --- + with logfire.span('teach-rules'): + result1 = agent.run_sync( + 'Remember these coding rules:\n' + '1. Always use f-strings in Python, never .format() or % formatting\n' + '2. In TypeScript, prefer const over let, never use var\n' + '3. 
Always add type hints to Python function signatures' + ) + print(f'Assistant: {result1.output}') + + rules = store.list_all() + print(f'\nRules stored: {len(rules)}') + for r in rules: + print(f' [{r.key}] {r.content} (namespace={r.namespace}, tags={r.tags})') + + assert len(rules) >= 3, f'Expected at least 3 rules saved, got {len(rules)}' + + # Check that search works across stored rules + python_rules = store.search('python') + print(f'Rules matching "python": {len(python_rules)}') + assert len(python_rules) >= 1, 'Expected at least 1 rule matching "python"' + + # --- Verify instructions injection --- + # Build instructions should now include the rules + from unittest.mock import MagicMock + + from pydantic_ai._run_context import RunContext + from pydantic_ai.usage import RunUsage + + ctx: RunContext[None] = RunContext(deps=None, model=MagicMock(), usage=RunUsage()) + instructions = memory.build_instructions(ctx) + print(f'\nInstructions preview (first 300 chars):\n{instructions[:300]}...') + + assert 'Currently stored memories' in instructions, 'Expected memories in instructions' + + # --- Ask for code, verify rules are considered --- + with logfire.span('apply-rules'): + result2 = agent.run_sync( + 'Write a Python function that greets a user by name. Follow all coding rules you know.' + ) + print(f'\nAssistant: {result2.output}') + + # The output should use f-strings and type hints (based on rules) + output_lower = result2.output.lower() + assert "f'" in result2.output or 'f"' in result2.output or 'f-string' in output_lower, ( + 'Expected f-string usage in code output' + ) + + # --- Delete an obsolete rule --- + with logfire.span('delete-rule'): + result3 = agent.run_sync('Actually, the TypeScript const rule is outdated for this project. 
Delete it.') + print(f'\nAssistant: {result3.output}') + + remaining = store.list_all() + print(f'\nRules after deletion: {len(remaining)}') + for r in remaining: + print(f' [{r.key}] {r.content}') + + # Should have fewer rules now + assert len(remaining) < len(rules), 'Expected at least one rule deleted' + + print('\n--- Coding Assistant example passed! ---') + + +if __name__ == '__main__': + sys.exit(main() or 0) diff --git a/examples/memory/personal_assistant.py b/examples/memory/personal_assistant.py new file mode 100644 index 0000000..753f46c --- /dev/null +++ b/examples/memory/personal_assistant.py @@ -0,0 +1,89 @@ +"""Personal Assistant — remembers user preferences across sessions. + +Demonstrates: FileMemoryStore persistence, save/recall, instructions injection, tags, scoping. +""" + +from __future__ import annotations + +import sys +import tempfile +from pathlib import Path + +import logfire +from pydantic_ai import Agent + +from pydantic_ai_harness.memory import FileMemoryStore, Memory + +logfire.configure(send_to_logfire='if-token-present') +logfire.instrument_openai() # pyright: ignore[reportUnknownMemberType] + + +def main() -> None: + """Run the personal assistant example.""" + with tempfile.TemporaryDirectory() as tmpdir: + mem_path = Path(tmpdir) / 'preferences.json' + store = FileMemoryStore(mem_path) + memory = Memory(store=store) + + agent = Agent( + 'openai:gpt-4o-mini', + capabilities=[memory], + system_prompt=( + 'You are a helpful personal assistant. ' + 'When the user tells you about their preferences, save each one as a memory ' + 'with namespace ["user_prefs"] and appropriate tags. ' + 'Use descriptive keys like "preferred_name" or "theme_preference".' + ), + ) + + # --- Session 1: user shares preferences --- + with logfire.span('session-1-save-preferences'): + result1 = agent.run_sync("Hi! 
My name is Alice, I prefer dark mode, and I'm vegetarian.") + print(f'Assistant: {result1.output}') + + entries = store.list_all() + print(f'\nMemories after session 1: {len(entries)}') + for e in entries: + print(f' [{e.key}] {e.content} (tags={e.tags}, namespace={e.namespace})') + + assert len(entries) >= 2, f'Expected at least 2 memories saved, got {len(entries)}' + all_content = ' '.join(e.content.lower() for e in entries) + assert 'alice' in all_content or any('alice' in e.key.lower() for e in entries), 'Expected a memory about Alice' + + # --- Session 2: new agent instance loads from same file (persistence) --- + store2 = FileMemoryStore(mem_path) + memory2 = Memory(store=store2) + agent2 = Agent( + 'openai:gpt-4o-mini', + capabilities=[memory2], + system_prompt='You are a helpful personal assistant.', + ) + + loaded_entries = store2.list_all() + print(f'\nMemories loaded in session 2: {len(loaded_entries)}') + assert len(loaded_entries) == len(entries), 'FileMemoryStore persistence failed' + + with logfire.span('session-2-recall-preferences'): + result2 = agent2.run_sync('What do you know about me?') + print(f'Assistant: {result2.output}') + + # The instructions injection should have included the memories + assert 'alice' in result2.output.lower() or 'dark' in result2.output.lower(), ( + 'Expected assistant to recall preferences from instructions injection' + ) + + # --- Session 3: update a preference --- + with logfire.span('session-3-update-preference'): + result3 = agent2.run_sync('Actually, I go by Ali now. Please update my name.') + print(f'\nAssistant: {result3.output}') + + updated_entries = store2.list_all() + print(f'\nMemories after update: {len(updated_entries)}') + for e in updated_entries: + print(f' [{e.key}] {e.content} (tags={e.tags})') + + print('\n--- Personal Assistant example passed! 
---') + + +if __name__ == '__main__': + sys.exit(main() or 0) diff --git a/examples/memory/postgres_store.py b/examples/memory/postgres_store.py new file mode 100644 index 0000000..adac92a --- /dev/null +++ b/examples/memory/postgres_store.py @@ -0,0 +1,167 @@ +"""Postgres backend for the Memory capability — reference implementation. + +Shows how to implement `MemoryStore` against Postgres using `psycopg`. This is +a starting point, not a production-ready backend: adapt to your deployment +(connection pooling, schema migrations, async, full-text search via tsvector). + +For semantic retrieval, swap the `search` implementation for one that runs +`SELECT ... ORDER BY embedding <=> %s` against a pgvector column. + +Schema: + CREATE TABLE memories ( + key TEXT PRIMARY KEY, + namespace TEXT[] NOT NULL DEFAULT ARRAY['global'], + data JSONB NOT NULL + ); + +Usage: + pip install 'psycopg[binary]' + + import psycopg + from pydantic_ai import Agent + from pydantic_ai_harness.memory import Memory + from examples.memory.postgres_store import PostgresMemoryStore + + conn = psycopg.connect('postgresql://localhost/myapp') + agent = Agent('openai:gpt-4o', capabilities=[Memory(store=PostgresMemoryStore(conn))]) +""" + +from __future__ import annotations + +import json +import re +from typing import Any + +import psycopg + +from pydantic_ai_harness.memory import MemoryEntry, RecencyScorer + +SCHEMA = """ +CREATE TABLE IF NOT EXISTS memories ( + key TEXT PRIMARY KEY, + namespace TEXT[] NOT NULL DEFAULT ARRAY['global'], + data JSONB NOT NULL +); +CREATE INDEX IF NOT EXISTS memories_namespace_idx ON memories USING GIN (namespace); +""" + + +class PostgresMemoryStore: + """`MemoryStore` backed by Postgres via psycopg. + + Implements the full `MemoryStore` Protocol: `get`, `put`, `delete`, + `list_all`, `search`, `list_namespaces`. Filtering happens in SQL + (namespace prefix via array slicing, metadata equality via JSONB ops); + keyword scoring runs in Python after the DB pre-filter. 
+ """ + + def __init__(self, conn: psycopg.Connection[Any]) -> None: + self._conn = conn + with self._conn.cursor() as cur: + cur.execute(SCHEMA) + self._conn.commit() + + def get(self, key: str) -> MemoryEntry | None: + with self._conn.cursor() as cur: + cur.execute('SELECT data FROM memories WHERE key = %s', (key,)) + row = cur.fetchone() + if row is None: + return None + entry = MemoryEntry.from_dict(row[0]) + return None if entry.is_expired() else entry + + def put(self, entry: MemoryEntry) -> None: + with self._conn.cursor() as cur: + cur.execute( + 'INSERT INTO memories (key, namespace, data) VALUES (%s, %s, %s) ' + 'ON CONFLICT (key) DO UPDATE SET namespace = EXCLUDED.namespace, data = EXCLUDED.data', + (entry.key, list(entry.namespace), json.dumps(entry.to_dict())), + ) + self._conn.commit() + + def delete(self, key: str) -> bool: + with self._conn.cursor() as cur: + cur.execute('DELETE FROM memories WHERE key = %s', (key,)) + deleted = cur.rowcount > 0 + self._conn.commit() + return deleted + + def list_all( + self, + *, + namespace: tuple[str, ...] | None = None, + filter: dict[str, object] | None = None, + ) -> list[MemoryEntry]: + sql = 'SELECT data FROM memories WHERE TRUE' + params: list[Any] = [] + if namespace is not None: + # Prefix match: entry.namespace[1:N] = supplied namespace tuple + sql += ' AND namespace[1:%s] = %s' + params.extend([len(namespace), list(namespace)]) + if filter is not None: + for k, v in filter.items(): + sql += " AND (data -> 'metadata' ->> %s) = %s" + params.extend([k, str(v)]) + with self._conn.cursor() as cur: + cur.execute(sql, params) + rows = cur.fetchall() + entries = [MemoryEntry.from_dict(r[0]) for r in rows] + return [e for e in entries if not e.is_expired()] + + def search( + self, + query: str, + *, + namespace: tuple[str, ...] 
| None = None, + filter: dict[str, object] | None = None, + recency_scorer: RecencyScorer | None = None, + ) -> list[MemoryEntry]: + # SQL-side filter, Python-side keyword scoring matching DictMemoryStore semantics. + # Production: replace with full-text or pgvector ranking. + words = query.lower().split() + if not words: + return [] + candidates = self.list_all(namespace=namespace, filter=filter) + scored: list[tuple[float, MemoryEntry]] = [] + for entry in candidates: + base = 0 + for word in words: + pattern = re.compile(rf'(? list[tuple[str, ...]]: + with self._conn.cursor() as cur: + cur.execute('SELECT DISTINCT namespace FROM memories') + rows = cur.fetchall() + seen: set[tuple[str, ...]] = set() + for row in rows: + ns: tuple[str, ...] = tuple(row[0]) + if max_depth is not None: + ns = ns[:max_depth] + if prefix is not None and (len(ns) < len(prefix) or ns[: len(prefix)] != prefix): + continue + if suffix is not None and (len(ns) < len(suffix) or ns[-len(suffix) :] != suffix): + continue + seen.add(ns) + return sorted(seen) diff --git a/examples/memory/study_coach.py b/examples/memory/study_coach.py new file mode 100644 index 0000000..2c7188e --- /dev/null +++ b/examples/memory/study_coach.py @@ -0,0 +1,76 @@ +"""Study Coach — spaced repetition with TTL. + +Demonstrates: TTL/expiration, save with ttl_minutes, list/search, tags. +""" + +from __future__ import annotations + +import sys + +import logfire +from pydantic_ai import Agent + +from pydantic_ai_harness.memory import DictMemoryStore, Memory + +logfire.configure(send_to_logfire='if-token-present') +logfire.instrument_openai() # pyright: ignore[reportUnknownMemberType] + + +def main() -> None: + """Run the study coach example.""" + store = DictMemoryStore() + memory = Memory(store=store) + + agent = Agent( + 'openai:gpt-4o-mini', + capabilities=[memory], + system_prompt=( + 'You are a study coach that helps users learn facts. 
' + 'When the user provides a fact to learn, save it as a memory with ' + 'tag "study" and a ttl_minutes value: use 1 for new/hard facts, ' + '60 for reviewed facts, and 1440 for mastered facts. ' + 'Use descriptive keys like "biology_mitochondria" or "history_magna_carta".' + ), + ) + + # --- Learn some facts --- + with logfire.span('learn-facts'): + result1 = agent.run_sync( + 'I need to learn these facts:\n' + '1. Mitochondria are the powerhouse of the cell\n' + '2. The Magna Carta was signed in 1215\n' + '3. Water boils at 100 degrees Celsius at sea level' + ) + print(f'Coach: {result1.output}') + + entries = store.list_all() + print(f'\nFacts stored: {len(entries)}') + for e in entries: + print(f' [{e.key}] {e.content} (tags={e.tags}, ttl={e.expires_at})') + + assert len(entries) >= 3, f'Expected at least 3 facts saved, got {len(entries)}' + + # Check that TTL was set on at least some entries + entries_with_ttl = [e for e in entries if e.expires_at is not None] + assert len(entries_with_ttl) >= 1, 'Expected at least 1 entry with TTL set' + print(f'Entries with TTL: {len(entries_with_ttl)}') + + # Check tags + entries_with_study_tag = [e for e in entries if 'study' in e.tags] + assert len(entries_with_study_tag) >= 1, 'Expected at least 1 entry with "study" tag' + + # --- Search for facts --- + with logfire.span('search-facts'): + result2 = agent.run_sync('Search my memories for anything about biology.') + print(f'\nCoach: {result2.output}') + + # --- List all facts --- + with logfire.span('list-facts'): + result3 = agent.run_sync('List all my study memories.') + print(f'\nCoach: {result3.output}') + + print('\n--- Study Coach example passed! 
---') + + +if __name__ == '__main__': + sys.exit(main() or 0) diff --git a/pydantic_ai_harness/__init__.py b/pydantic_ai_harness/__init__.py index 0a60fd7..e7881f5 100644 --- a/pydantic_ai_harness/__init__.py +++ b/pydantic_ai_harness/__init__.py @@ -4,8 +4,40 @@ if TYPE_CHECKING: from .code_mode import CodeMode + from .memory import ( + DictMemoryStore, + FileMemoryStore, + Memory, + MemoryEntry, + MemoryEntryDict, + MemoryStore, + RecencyScorer, + exponential_decay, + ) -__all__ = ['CodeMode'] +__all__ = [ + 'CodeMode', + 'DictMemoryStore', + 'FileMemoryStore', + 'Memory', + 'MemoryEntry', + 'MemoryEntryDict', + 'MemoryStore', + 'RecencyScorer', + 'exponential_decay', +] + + +_MEMORY_NAMES = { + 'DictMemoryStore', + 'FileMemoryStore', + 'Memory', + 'MemoryEntry', + 'MemoryEntryDict', + 'MemoryStore', + 'RecencyScorer', + 'exponential_decay', +} def __getattr__(name: str) -> object: @@ -13,4 +45,8 @@ def __getattr__(name: str) -> object: from .code_mode import CodeMode return CodeMode + if name in _MEMORY_NAMES: + from . import memory + + return getattr(memory, name) raise AttributeError(f'module {__name__!r} has no attribute {name!r}') diff --git a/pydantic_ai_harness/memory/README.md b/pydantic_ai_harness/memory/README.md new file mode 100644 index 0000000..b89f073 --- /dev/null +++ b/pydantic_ai_harness/memory/README.md @@ -0,0 +1,161 @@ +# Memory + +Persistent key-value memory across agent sessions. Provides five tools the LLM +can call (`save_memory`, `recall_memory`, `search_memories`, `list_memories`, +`delete_memory`) and injects currently-stored memories into the system prompt +each run. + +## Quick start + +```python +from pydantic_ai import Agent +from pydantic_ai_harness import Memory + +agent = Agent('openai:gpt-4o', capabilities=[Memory()]) +``` + +By default `Memory()` uses an in-process `DictMemoryStore` — entries live only +for the lifetime of the Python process. 
Use `FileMemoryStore` for single-user +single-process persistence, or implement `MemoryStore` for anything else. + +## Built-in backends + +| Backend | Persistence | Concurrency | Use case | +|---|---|---|---| +| `DictMemoryStore` | None (in-process) | Single-thread | Tests, scratch agents | +| `FileMemoryStore(path)` | JSON file on disk | Single-process | Single-user CLI agents | + +For Postgres, Redis, vector DBs, etc. — implement the `MemoryStore` Protocol. +See [`examples/memory/postgres_store.py`](https://github.com/pydantic/pydantic-ai-harness/blob/main/examples/memory/postgres_store.py) +for a reference implementation. + +## `MemoryEntry` fields + +| Field | Type | Default | Notes | +|---|---|---|---| +| `key` | `str` | required | Unique identifier | +| `content` | `str` | required | The fact itself | +| `tags` | `list[str]` | `[]` | LLM-set categorisation | +| `namespace` | `tuple[str, ...]` | `('global',)` | Hierarchical namespace; prefix-matched in queries | +| `expires_at` | `str \| None` | `None` | ISO 8601 wall-clock expiry; opt-in TTL | +| `created_at`, `updated_at` | `str` | now | ISO 8601 timestamps | +| `summary` | `str \| None` | `None` | Short version preferred over `content` for prompt injection | +| `metadata` | `dict[str, object]` | `{}` | Structured attributes; filterable via `search(filter=...)` | +| `read_only` | `bool` | `False` | If True, agent's tools refuse to modify | +| `char_limit` | `int \| None` | `None` | Optional hard cap on `content` length (raises at construction) | +| `importance` | `float \| None` | `None` | Search-score booster | + +## Namespaces + +Hierarchical, tuple-based. 
Filters in `list_all`/`search` use prefix matching: + +```python +store.put(MemoryEntry(key='a', content='...', namespace=('users', 'alice'))) +store.put(MemoryEntry(key='b', content='...', namespace=('users', 'bob'))) +store.put(MemoryEntry(key='c', content='...', namespace=('agents', 'planner'))) + +store.list_all(namespace=('users',)) # → [a, b] +store.list_namespaces(prefix=('users',)) # → [('users', 'alice'), ('users', 'bob')] +``` + +## Multi-agent shared memory + +One store, two agents, separate namespaces: + +```python +from pydantic_ai import Agent +from pydantic_ai_harness import FileMemoryStore, Memory + +shared = FileMemoryStore('/var/lib/myapp/memory.json') + +planner = Agent('openai:gpt-4o', capabilities=[ + Memory(store=shared, byte_budget=2000), +]) +worker = Agent('openai:gpt-4o-mini', capabilities=[ + Memory(store=shared, byte_budget=500), +]) +``` + +Entries written by either agent are visible to both. Use `namespace=('agents', +'planner')` etc. on saves to keep their workspaces separate while still sharing +common facts in `('global',)`. + +## Search + +Word-boundary regex on key, content, and tags. Final score: + +``` +keyword_match_count + (entry.importance or 0) + (recency_scorer(entry) or 0) +``` + +Recency boost is enabled by default — `Memory` ships with +`exponential_decay(half_life_days=30, weight=0.5)`. Override with any callable: + +```python +from pydantic_ai_harness import Memory, exponential_decay + +# Tighter half-life for fast-moving information +Memory(recency_scorer=exponential_decay(half_life_days=7)) + +# Custom: boost only entries with the 'pinned' tag +Memory(recency_scorer=lambda e: 1.0 if 'pinned' in e.tags else 0.0) + +# Disable recency entirely +Memory(recency_scorer=None) +``` + +## Prompt-cache trade-off + +Every save/delete changes the injected memories block in the system prompt, +invalidating the prompt-cache prefix. Read-heavy workloads keep the cache; +write-heavy workloads thrash. 
+ +Mitigation: `Memory(inject_memories_in_instructions=False)` skips the +injection. The LLM reads memories only via explicit `list_memories` / +`search_memories` / `recall_memory` calls — system prompt prefix stays stable +across writes. + +For partial mitigation, set `byte_budget` to cap the injected block size. + +## Dedup against tool history + +Default `dedup_recent_saves=True` suppresses injection of an entry when the +LLM has already seen its current value via a `save_memory` call in this run's +tool history. Content-aware: if something updated the entry externally (e.g., +another agent), the saved content no longer matches the store, so the entry +is injected so the LLM sees the current value. `read_only=True` entries are +never suppressed. Disable with `dedup_recent_saves=False`. + +## Tool description overrides + +```python +Memory(tool_descriptions={ + 'save_memory': 'Save anything the user mentions about themselves, ' + 'even tiny details. Tag with "user_pref".', +}) +``` + +## Custom backends + +Implement the `MemoryStore` Protocol — six methods; every optional parameter is keyword-only: + +```python +from pydantic_ai_harness import MemoryEntry, MemoryStore, RecencyScorer + +class MyStore: + def get(self, key: str) -> MemoryEntry | None: ... + def put(self, entry: MemoryEntry) -> None: ... + def delete(self, key: str) -> bool: ... + def list_all(self, *, namespace=None, filter=None): ... + def search(self, query, *, namespace=None, filter=None, recency_scorer=None): ... + def list_namespaces(self, *, prefix=None, suffix=None, max_depth=None): ... +``` + +Drop into any `Memory(store=MyStore())`. See the Postgres example for a +working reference. + +## Known followups + +- **Semantic retrieval**: `SemanticMemoryStore` Protocol extension and an + `EmbeddingStore` reference impl. Deferred until a concrete backend (Qdrant / + pgvector / LanceDB) drives the API design. 
diff --git a/pydantic_ai_harness/memory/__init__.py b/pydantic_ai_harness/memory/__init__.py new file mode 100644 index 0000000..5f3b6d9 --- /dev/null +++ b/pydantic_ai_harness/memory/__init__.py @@ -0,0 +1,23 @@ +"""Memory capability: persistent key-value memory across agent sessions.""" + +from pydantic_ai_harness.memory._capability import ( + DictMemoryStore, + FileMemoryStore, + Memory, + MemoryEntry, + MemoryEntryDict, + MemoryStore, + RecencyScorer, + exponential_decay, +) + +__all__ = [ + 'DictMemoryStore', + 'FileMemoryStore', + 'Memory', + 'MemoryEntry', + 'MemoryEntryDict', + 'MemoryStore', + 'RecencyScorer', + 'exponential_decay', +] diff --git a/pydantic_ai_harness/memory/_capability.py b/pydantic_ai_harness/memory/_capability.py new file mode 100644 index 0000000..70d1779 --- /dev/null +++ b/pydantic_ai_harness/memory/_capability.py @@ -0,0 +1,815 @@ +"""Memory capability for persistent agent memory across sessions. + +Provides tools for saving, recalling, searching, listing, and deleting +key-value memories, with pluggable storage backends (`DictMemoryStore` for +testing, `FileMemoryStore` for on-disk persistence). 
+""" + +from __future__ import annotations + +import json +import logging +import re +from collections.abc import Callable +from dataclasses import dataclass, field +from datetime import datetime, timedelta, timezone +from pathlib import Path +from typing import Any, Protocol, TypeAlias, TypedDict, runtime_checkable + +from pydantic_ai._instructions import AgentInstructions +from pydantic_ai.capabilities.abstract import AbstractCapability +from pydantic_ai.messages import ModelMessage, ModelResponse, ToolCallPart +from pydantic_ai.tools import AgentDepsT, RunContext, Tool +from pydantic_ai.toolsets import AgentToolset +from pydantic_ai.toolsets.function import FunctionToolset + +logger = logging.getLogger(__name__) + + +class _MemoryEntryDictRequired(TypedDict): + """Required fields for MemoryEntryDict.""" + + key: str + content: str + + +class MemoryEntryDict(_MemoryEntryDictRequired, total=False): + """Serialized form of a MemoryEntry for JSON storage. + + Only `key` and `content` are required; the remaining fields are + optional so that `from_dict` can accept legacy data missing some keys. + """ + + tags: list[str] + namespace: list[str] + expires_at: str | None + created_at: str + updated_at: str + summary: str | None + metadata: dict[str, object] + read_only: bool + char_limit: int | None + importance: float | None + + +@dataclass +class MemoryEntry: + """A single memory entry with content, tags, and timestamps.""" + + key: str + """Unique identifier for this memory.""" + + content: str + """The content of the memory.""" + + tags: list[str] = field(default_factory=lambda: list[str]()) + """Optional tags for categorization and search.""" + + namespace: tuple[str, ...] = ('global',) + """Hierarchical namespace for this memory. + + A tuple of strings forming a path-like namespace (e.g., `('users', 'alice')`, + `('agents', 'planner', 'facts')`). 
Filters in `list_all`/`search` use prefix + matching: `namespace=('users',)` matches `('users', 'alice')` and + `('users', 'bob')`. Default `('global',)` is a single-segment namespace. + """ + + expires_at: str | None = None + """Optional ISO 8601 expiration timestamp. `None` means no expiry.""" + + created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) + """ISO 8601 timestamp of when the memory was first created.""" + + updated_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) + """ISO 8601 timestamp of the last update.""" + + summary: str | None = None + """Optional short summary used by `Memory.build_instructions` when injecting this entry into the system prompt; falls back to `content` when None.""" + + metadata: dict[str, object] = field(default_factory=lambda: dict[str, object]()) + """Structured attributes for filterable search (`MemoryStore.search(filter=...)`). Values must be JSON-serializable.""" + + read_only: bool = False + """If True, the agent's `save_memory` and `delete_memory` tools refuse to modify this entry. Programmatic access via the store is unrestricted.""" + + char_limit: int | None = None + """Optional hard cap on `content` length (chars). Enforced at `MemoryEntry` construction; raises `ValueError` if exceeded.""" + + importance: float | None = None + """Optional relevance booster used by `MemoryStore.search` scoring when set.""" + + def __post_init__(self) -> None: + """Validate `char_limit` immediately so dev errors surface at construction.""" + if self.char_limit is not None and len(self.content) > self.char_limit: + raise ValueError( + f'MemoryEntry {self.key!r} content is {len(self.content)} chars, exceeds char_limit={self.char_limit}', + ) + + def is_expired(self) -> bool: + """Return True if this entry has passed its expiration time. 
+ + Wall-clock semantics: an entry created with `ttl_minutes=N` expires `N` + minutes after creation in real time, regardless of how many agent turns + or sessions have elapsed. TTL is opt-in (default `expires_at=None` = + no expiry) and intended for facts with a real-world lifetime + (verification codes, session credentials, etc.). + """ + if self.expires_at is None: + return False + return datetime.fromisoformat(self.expires_at) <= datetime.now(timezone.utc) + + def to_dict(self) -> MemoryEntryDict: + """Serialize to a plain dict for JSON storage.""" + return { + 'key': self.key, + 'content': self.content, + 'tags': self.tags, + 'namespace': list(self.namespace), + 'expires_at': self.expires_at, + 'created_at': self.created_at, + 'updated_at': self.updated_at, + 'summary': self.summary, + 'metadata': self.metadata, + 'read_only': self.read_only, + 'char_limit': self.char_limit, + 'importance': self.importance, + } + + @classmethod + def from_dict(cls, data: MemoryEntryDict) -> MemoryEntry: + """Deserialize from a plain dict.""" + return cls( + key=data['key'], + content=data['content'], + tags=data.get('tags', []), + namespace=tuple(data.get('namespace', ('global',))), + expires_at=data.get('expires_at'), + created_at=data.get('created_at', ''), + updated_at=data.get('updated_at', ''), + summary=data.get('summary'), + metadata=data.get('metadata', {}), + read_only=data.get('read_only', False), + char_limit=data.get('char_limit'), + importance=data.get('importance'), + ) + + +def _score_entry(entry: MemoryEntry, words: list[str]) -> int: + r"""Score a memory entry by counting word-boundary matches across fields. + + Each query word that appears as a whole word (case-insensitive) in the + key, content, or any tag contributes one point per field it appears in. + Underscores and hyphens are treated as word separators in addition to + the standard `\\b` boundaries. 
def exponential_decay(half_life_days: float = 30.0, weight: float = 1.0) -> RecencyScorer:
    """Build a recency scorer with exponential decay over `entry.updated_at`.

    Args:
        half_life_days: Age (in days) at which the decay value is halved. Must be
            positive. Default `30.0`.
        weight: Multiplier applied to the decay value. Default `1.0`.

    Returns:
        A `RecencyScorer` callable. Entries with an unparsable `updated_at`
        return `0.0`; future-dated entries return the full `weight`. An
        `updated_at` without a UTC offset is interpreted as UTC.

    Raises:
        ValueError: If `half_life_days` is not positive (a zero half-life would
            otherwise only fail later, inside the scorer, with `ZeroDivisionError`).
    """
    if half_life_days <= 0:
        raise ValueError(f'half_life_days must be positive, got {half_life_days!r}')

    def scorer(entry: MemoryEntry) -> float:
        try:
            updated = datetime.fromisoformat(entry.updated_at)
        except (TypeError, ValueError):
            # Missing, empty, or malformed timestamps contribute no recency boost.
            return 0.0
        if updated.tzinfo is None:
            # Subtracting a naive datetime from an aware one raises TypeError; treat naive as UTC.
            updated = updated.replace(tzinfo=timezone.utc)
        age_seconds = (datetime.now(timezone.utc) - updated).total_seconds()
        if age_seconds < 0:
            return weight
        age_days = age_seconds / 86400
        return weight * (2 ** (-age_days / half_life_days))

    return scorer
+ """ + last: dict[str, str] = {} + for msg in messages: + if not isinstance(msg, ModelResponse): + continue + for part in msg.parts: + if not isinstance(part, ToolCallPart): + continue + if part.tool_name != 'save_memory': + continue + args = part.args_as_dict() + key = args.get('key') + content = args.get('content') + if isinstance(key, str) and isinstance(content, str): + last[key] = content + return last + + +def _matches_filter(entry: MemoryEntry, filter_: dict[str, object]) -> bool: + """Return True if all filter keys match `entry.metadata` values exactly.""" + for key, value in filter_.items(): + if entry.metadata.get(key) != value: + return False + return True + + +def _namespace_matches(entry_ns: tuple[str, ...], filter_prefix: tuple[str, ...]) -> bool: + """Return True if `entry_ns` starts with `filter_prefix`. + + Empty `filter_prefix` matches any namespace (use `None` in callers to mean + "no filter" — this helper assumes the caller has decided to filter). + """ + if len(entry_ns) < len(filter_prefix): + return False + return entry_ns[: len(filter_prefix)] == filter_prefix + + +def _simple_similarity(a: str, b: str) -> bool: + """Return True if two keys share the same first 10 characters and differ only slightly. + + Uses a simple character-level edit distance check: keys are considered + similar when they share the same 10-char prefix and differ by at most 2 + characters (Levenshtein-like). 
+ """ + if len(a) < 10 or len(b) < 10: + return False + if a[:10] != b[:10]: + return False + if a == b: + return False + # Simple Levenshtein-like check: allow at most 2 edits + if abs(len(a) - len(b)) > 2: + return False + # Bounded character-level distance (sufficient for dedup warnings) + max_edits = 2 + m, n = len(a), len(b) + prev = list(range(n + 1)) + for i in range(1, m + 1): + curr = [i] + [0] * n + for j in range(1, n + 1): + cost = 0 if a[i - 1] == b[j - 1] else 1 + curr[j] = min(curr[j - 1] + 1, prev[j] + 1, prev[j - 1] + cost) + prev = curr + return prev[n] <= max_edits + + +@runtime_checkable +class MemoryStore(Protocol): + """Protocol for pluggable memory storage backends.""" + + def get(self, key: str) -> MemoryEntry | None: # pragma: no cover + """Retrieve a memory entry by key, or None if not found.""" + ... + + def put(self, entry: MemoryEntry) -> None: # pragma: no cover + """Store or update a memory entry.""" + ... + + def delete(self, key: str) -> bool: # pragma: no cover + """Delete a memory entry by key. Returns True if it existed.""" + ... + + def list_all( # pragma: no cover + self, + *, + namespace: tuple[str, ...] | None = None, + filter: dict[str, object] | None = None, + ) -> list[MemoryEntry]: + """Return all non-expired entries, optionally filtered by namespace prefix and metadata equality.""" + ... + + def search( # pragma: no cover + self, + query: str, + *, + namespace: tuple[str, ...] | None = None, + filter: dict[str, object] | None = None, + recency_scorer: RecencyScorer | None = None, + ) -> list[MemoryEntry]: + """Search non-expired entries, sorted by relevance. + + Score = keyword-match count + entry.importance (if set) + recency_scorer(entry) (if provided). + Entries with zero keyword match are excluded regardless of recency or importance. + """ + ... + + def list_namespaces( # pragma: no cover + self, + *, + prefix: tuple[str, ...] | None = None, + suffix: tuple[str, ...] 
| None = None, + max_depth: int | None = None, + ) -> list[tuple[str, ...]]: + """List unique namespaces among non-expired entries, optionally filtered. + + Args: + prefix: Only include namespaces starting with this prefix. + suffix: Only include namespaces ending with this suffix. + max_depth: Truncate each namespace to at most this many segments before deduplication. + """ + ... + + +class _BaseDictStore: + """Base class for dict-backed memory stores.""" + + _entries: dict[str, MemoryEntry] + + def get(self, key: str) -> MemoryEntry | None: + """Retrieve a non-expired memory entry by key.""" + entry = self._entries.get(key) + if entry is None or entry.is_expired(): + return None + return entry + + def put(self, entry: MemoryEntry) -> None: + """Store or update a memory entry.""" + self._entries[entry.key] = entry + + def delete(self, key: str) -> bool: + """Delete a memory entry by key.""" + return self._entries.pop(key, None) is not None + + def _gc_expired(self) -> None: + """Drop expired entries from the backing dict.""" + expired_keys = [key for key, entry in self._entries.items() if entry.is_expired()] + for key in expired_keys: + del self._entries[key] + + def list_all( + self, + *, + namespace: tuple[str, ...] | None = None, + filter: dict[str, object] | None = None, + ) -> list[MemoryEntry]: + """Return all non-expired entries, optionally filtered by namespace prefix and metadata equality.""" + return [ + entry + for entry in self._entries.values() + if not entry.is_expired() + and (namespace is None or _namespace_matches(entry.namespace, namespace)) + and (filter is None or _matches_filter(entry, filter)) + ] + + def search( + self, + query: str, + *, + namespace: tuple[str, ...] | None = None, + filter: dict[str, object] | None = None, + recency_scorer: RecencyScorer | None = None, + ) -> list[MemoryEntry]: + """Search non-expired entries with word-boundary matching, sorted by relevance. 
+ + Score = keyword-match count + `entry.importance` (if set) + `recency_scorer(entry)` (if provided). + """ + words = query.lower().split() + if not words: + return [] + scored: list[tuple[float, MemoryEntry]] = [] + for entry in self._entries.values(): + if entry.is_expired(): + continue + if namespace is not None and not _namespace_matches(entry.namespace, namespace): + continue + if filter is not None and not _matches_filter(entry, filter): + continue + base = _score_entry(entry, words) + if base == 0: + continue + score: float = float(base) + if entry.importance is not None: + score += entry.importance + if recency_scorer is not None: + score += recency_scorer(entry) + scored.append((score, entry)) + scored.sort(key=lambda pair: pair[0], reverse=True) + return [entry for _, entry in scored] + + def list_namespaces( + self, + *, + prefix: tuple[str, ...] | None = None, + suffix: tuple[str, ...] | None = None, + max_depth: int | None = None, + ) -> list[tuple[str, ...]]: + """List unique namespaces among non-expired entries.""" + seen: set[tuple[str, ...]] = set() + for entry in self._entries.values(): + if entry.is_expired(): + continue + ns = entry.namespace + if max_depth is not None: + ns = ns[:max_depth] + if prefix is not None and not _namespace_matches(ns, prefix): + continue + if suffix is not None and (len(ns) < len(suffix) or ns[-len(suffix) :] != suffix): + continue + seen.add(ns) + return sorted(seen) + + +class DictMemoryStore(_BaseDictStore): + """Dict-based in-memory store, suitable for testing. + + All data lives in a plain `dict` and is lost when the process exits. + """ + + def __init__(self) -> None: + """Initialize an empty in-memory store.""" + self._entries: dict[str, MemoryEntry] = {} + + +class FileMemoryStore(_BaseDictStore): + """JSON-file-based store for simple on-disk persistence. + + Reads the file on initialization and writes back on every mutation. 
def format_entry(entry: MemoryEntry, *, prefer_summary: bool = False) -> str:
    """Render a memory entry as a single human-readable line.

    The line has the form `[key] text`, followed by a parenthesised,
    semicolon-separated list of extras when any apply: tags, a non-default
    namespace (segments joined with `/`), and the expiry timestamp.

    Args:
        entry: The entry to format.
        prefer_summary: When True and `entry.summary` is set, show the summary
            instead of the full content. Defaults to False (full content).

    Returns:
        The formatted line.
    """
    use_summary = prefer_summary and entry.summary is not None
    text = entry.summary if use_summary else entry.content
    rendered = f'[{entry.key}] {text}'

    # Each candidate is None unless the corresponding attribute is worth showing.
    candidates = (
        f'tags: {", ".join(entry.tags)}' if entry.tags else None,
        f'namespace: {"/".join(entry.namespace)}' if entry.namespace != ('global',) else None,
        f'expires: {entry.expires_at}' if entry.expires_at is not None else None,
    )
    details = [item for item in candidates if item is not None]
    if not details:
        return rendered
    return rendered + f' ({"; ".join(details)})'
+ + `read_only=True` entries always inject regardless of this cap. + """ + + byte_budget: int | None = None + """Optional UTF-8 byte cap on the injected memories block. + + When set, non-pinned entries are skipped once adding the next would exceed + the budget. `read_only=True` entries always inject regardless of this cap. + Default `None` = no byte cap (only the count cap applies). + """ + + recency_scorer: RecencyScorer | None = field( + default_factory=lambda: exponential_decay(half_life_days=30.0, weight=0.5), + ) + """Recency scorer threaded into `search_memories` to bias results toward fresher entries. + + Defaults to `exponential_decay(half_life_days=30, weight=0.5)`. Set to `None` to disable. + Pass any `Callable[[MemoryEntry], float]` for custom decay shapes. + """ + + tool_descriptions: dict[str, str] = field(default_factory=lambda: dict[str, str]()) + """Per-tool description overrides. Keys are tool names (`save_memory`, `recall_memory`, + `search_memories`, `list_memories`, `delete_memory`); values replace the docstring used + by the LLM. Useful for nudging the agent (e.g., "Save aggressively, even small facts").""" + + dedup_recent_saves: bool = True + """When True, suppress injection of entries that match a `save_memory` call + in the current run's tool history (the LLM has already seen the value). + + The check is content-aware: if the store entry's `content` differs from the + most recent saved content (e.g., another process updated the entry), the + entry is injected so the LLM sees the current value. `read_only=True` + entries are never suppressed. + """ + + @classmethod + def get_serialization_name(cls) -> str | None: + """Return the name used for spec serialization.""" + return 'Memory' + + @classmethod + def from_spec( + cls, + *, + backend: str = 'memory', + path: str = '.memories.json', + inject_memories_in_instructions: bool = True, + max_instructions_memories: int = 20, + ) -> Memory[Any]: + """Create from spec arguments. 
def build_instructions(self, ctx: RunContext[AgentDepsT]) -> str:
    """Build the dynamic instructions block, listing currently stored memories.

    Selection rules:
    - `read_only=True` (pinned) entries are always listed and come first; they
      bypass the count cap, the byte budget, and dedup, but their bytes still
      count toward the budget seen by the entries after them.
    - Other entries are taken in store order until `max_instructions_memories`
      is reached or the next line would exceed `byte_budget`.
    - `entry.summary`, when set, is shown instead of `entry.content`.
    - With `dedup_recent_saves` enabled, an entry whose current content equals
      the most recent `save_memory` call for its key in `ctx.messages` is left out.
    - If any stored entry is not listed, a trailing line reports how many.

    Returns:
        The instruction text. Only the introductory sentence is returned when
        injection is disabled or the store has no entries.
    """
    preamble = (
        'You have access to a persistent memory system. '
        'Use it to save important information that should be remembered across conversations.'
    )
    if not self.inject_memories_in_instructions:
        return preamble

    stored = self.store.list_all()
    if not stored:
        return preamble

    already_seen: dict[str, str] = _saves_in_history(ctx.messages) if self.dedup_recent_saves else {}

    lines: list[str] = []
    byte_total = 0

    # Pass 1: pinned entries, unconditionally, in store order.
    for pinned in (item for item in stored if item.read_only):
        rendered = f'- {format_entry(pinned, prefer_summary=True)}'
        lines.append(rendered)
        byte_total += len(rendered.encode('utf-8'))

    # Pass 2: everything else, subject to dedup, the count cap, and the byte budget.
    taken = 0
    for candidate in (item for item in stored if not item.read_only):
        previously_saved = already_seen.get(candidate.key)
        if previously_saved is not None and previously_saved == candidate.content:
            continue
        if taken >= self.max_instructions_memories:
            break
        rendered = f'- {format_entry(candidate, prefer_summary=True)}'
        rendered_size = len(rendered.encode('utf-8'))
        if self.byte_budget is not None and byte_total + rendered_size > self.byte_budget:
            break
        lines.append(rendered)
        byte_total += rendered_size
        taken += 1

    sections: list[str] = [preamble, '\nCurrently stored memories:', *lines]
    hidden = len(stored) - len(lines)
    if hidden > 0:
        sections.append(f'... and {hidden} more (use list_memories or search_memories to see all).')
    return '\n'.join(sections)
+ """ + store = self.store + recency_scorer = self.recency_scorer + + def save_memory( + key: str, + content: str, + tags: list[str] | None = None, + namespace: list[str] | None = None, + ttl_minutes: int | None = None, + summary: str | None = None, + importance: float | None = None, + ) -> str: + """Save or update a memory entry. + + Args: + key: Unique key for this memory. + content: The content to remember. + tags: Optional tags for categorization and search. + namespace: Optional hierarchical namespace as a list of segments + (e.g., `['users', 'alice']`). Defaults to `['global']`. + ttl_minutes: Optional time-to-live in minutes. The entry will expire after this duration. + summary: Optional short summary; preferred over `content` when injecting into the system prompt. + importance: Optional relevance booster (e.g., 0.0–1.0); used by search scoring. + """ + now = datetime.now(timezone.utc) + now_iso = now.isoformat() + existing = store.get(key) + + if existing is not None and existing.read_only: + return f'Memory {key!r} is read-only and cannot be modified.' + + # Dedup warning: check for similar keys among existing entries + for existing_entry in store.list_all(): + if _simple_similarity(key, existing_entry.key): + logger.warning( + 'New memory key %r is very similar to existing key %r — possible duplicate', + key, + existing_entry.key, + ) + + expires_at: str | None = None + if ttl_minutes is not None: + expires_at = (now + timedelta(minutes=ttl_minutes)).isoformat() + + ns: tuple[str, ...] = tuple(namespace) if namespace else ('global',) + entry = MemoryEntry( + key=key, + content=content, + tags=tags or [], + namespace=ns, + expires_at=expires_at, + created_at=existing.created_at if existing else now_iso, + updated_at=now_iso, + summary=summary, + importance=importance, + ) + store.put(entry) + return f'Memory saved: {key}' + + def recall_memory(key: str) -> str: + """Recall a specific memory by its key. + + Args: + key: The key of the memory to recall. 
+ """ + entry = store.get(key) + if entry is None: + return f'No memory found for key: {key}' + return format_entry(entry) + + def search_memories(query: str, namespace: list[str] | None = None) -> str: + """Search memories by word-boundary matching on keys, content, or tags, sorted by relevance. + + Args: + query: The search query string (space-separated words). + namespace: Optional namespace prefix to restrict the search to (e.g., `['users']`). + """ + ns: tuple[str, ...] | None = tuple(namespace) if namespace else None + results = store.search(query, namespace=ns, recency_scorer=recency_scorer) + if not results: + return f'No memories found matching: {query}' + return '\n'.join(format_entry(entry) for entry in results) + + def list_memories(namespace: list[str] | None = None) -> str: + """List all stored memories, optionally filtered by namespace prefix. + + Args: + namespace: Optional namespace prefix to filter by (e.g., `['users']`). + """ + ns: tuple[str, ...] | None = tuple(namespace) if namespace else None + entries = store.list_all(namespace=ns) + if not entries: + return 'No memories stored.' + return '\n'.join(format_entry(entry) for entry in entries) + + def delete_memory(key: str) -> str: + """Delete a memory by its key. + + Args: + key: The key of the memory to delete. + """ + entry = store.get(key) + if entry is not None and entry.read_only: + return f'Memory {key!r} is read-only and cannot be deleted.' 
+ if store.delete(key): + return f'Memory deleted: {key}' + return f'No memory found for key: {key}' + + descs = self.tool_descriptions + return FunctionToolset( + [ + Tool(save_memory, takes_ctx=False, description=descs.get('save_memory')), + Tool(recall_memory, takes_ctx=False, description=descs.get('recall_memory')), + Tool(search_memories, takes_ctx=False, description=descs.get('search_memories')), + Tool(list_memories, takes_ctx=False, description=descs.get('list_memories')), + Tool(delete_memory, takes_ctx=False, description=descs.get('delete_memory')), + ], + ) diff --git a/pyproject.toml b/pyproject.toml index 002da51..87fbdd0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,6 +91,7 @@ extend-select = ['Q', 'RUF100', 'C90', 'UP', 'I', 'D', 'TID251'] [tool.ruff.lint.per-file-ignores] 'tests/**/*.py' = ['D'] +'examples/**/*.py' = ['D'] [tool.ruff.lint.flake8-quotes] inline-quotes = 'single' diff --git a/tests/test_memory.py b/tests/test_memory.py new file mode 100644 index 0000000..11e8034 --- /dev/null +++ b/tests/test_memory.py @@ -0,0 +1,1180 @@ +"""Tests for the Memory capability.""" + +from __future__ import annotations + +import json +import logging +from datetime import datetime, timedelta, timezone +from pathlib import Path +from typing import Any + +from pydantic_ai._run_context import RunContext +from pydantic_ai.toolsets.function import FunctionToolset +from pydantic_ai.usage import RunUsage + +from pydantic_ai_harness.memory import ( + DictMemoryStore, + FileMemoryStore, + Memory, + MemoryEntry, + MemoryStore, + exponential_decay, +) +from pydantic_ai_harness.memory._capability import ( + _score_entry, + _simple_similarity, + format_entry, +) + +# --- MemoryEntry --- + + +class TestMemoryEntry: + def test_round_trip(self) -> None: + entry = MemoryEntry( + key='k', + content='v', + tags=['a', 'b'], + namespace=('project',), + expires_at='2099-01-01T00:00:00+00:00', + created_at='t1', + updated_at='t2', + ) + assert 
MemoryEntry.from_dict(entry.to_dict()) == entry + + def test_from_dict_defaults(self) -> None: + entry = MemoryEntry.from_dict({'key': 'k', 'content': 'v'}) + assert entry.tags == [] + assert entry.namespace == ('global',) + assert entry.expires_at is None + assert entry.created_at == '' + assert entry.updated_at == '' + + def test_default_timestamps(self) -> None: + entry = MemoryEntry(key='k', content='v') + assert entry.created_at # non-empty ISO string + assert entry.updated_at + + def test_default_namespace(self) -> None: + entry = MemoryEntry(key='k', content='v') + assert entry.namespace == ('global',) + + def test_is_expired_no_expiry(self) -> None: + entry = MemoryEntry(key='k', content='v') + assert not entry.is_expired() + + def test_is_expired_future(self) -> None: + future = (datetime.now(timezone.utc) + timedelta(hours=1)).isoformat() + entry = MemoryEntry(key='k', content='v', expires_at=future) + assert not entry.is_expired() + + def test_is_expired_past(self) -> None: + past = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() + entry = MemoryEntry(key='k', content='v', expires_at=past) + assert entry.is_expired() + + def test_default_new_fields(self) -> None: + entry = MemoryEntry(key='k', content='v') + assert entry.summary is None + assert entry.metadata == {} + assert entry.read_only is False + assert entry.char_limit is None + assert entry.importance is None + + def test_round_trip_with_new_fields(self) -> None: + entry = MemoryEntry( + key='k', + content='v', + summary='short', + metadata={'priority': 1, 'source': 'manual'}, + read_only=True, + char_limit=100, + importance=0.8, + ) + assert MemoryEntry.from_dict(entry.to_dict()) == entry + + def test_char_limit_enforced(self) -> None: + import pytest + + with pytest.raises(ValueError, match='exceeds char_limit'): + MemoryEntry(key='k', content='hello world', char_limit=5) + + def test_char_limit_allows_exact(self) -> None: + # Exactly at the limit is allowed + MemoryEntry(key='k', 
content='hello', char_limit=5) + + +# --- _score_entry --- + + +class TestScoreEntry: + def test_no_match(self) -> None: + entry = MemoryEntry(key='greeting', content='hello world') + assert _score_entry(entry, ['zzz']) == 0 + + def test_key_match(self) -> None: + entry = MemoryEntry(key='greeting', content='some text') + assert _score_entry(entry, ['greeting']) == 1 + + def test_content_match(self) -> None: + entry = MemoryEntry(key='k', content='hello world') + assert _score_entry(entry, ['hello']) == 1 + + def test_tag_match(self) -> None: + entry = MemoryEntry(key='k', content='text', tags=['important']) + assert _score_entry(entry, ['important']) == 1 + + def test_multiple_field_match(self) -> None: + entry = MemoryEntry(key='hello', content='hello world', tags=['hello']) + # 'hello' appears in key (1) + content (1) + tags (1) = 3 + assert _score_entry(entry, ['hello']) == 3 + + def test_multiple_words(self) -> None: + entry = MemoryEntry(key='user', content='Alice likes blue') + # 'alice' in content (1), 'blue' in content (1) = 2 + assert _score_entry(entry, ['alice', 'blue']) == 2 + + def test_word_boundary_no_partial(self) -> None: + # 'fox' should NOT match 'foxes' with word-boundary matching + entry = MemoryEntry(key='k', content='foxes jump') + assert _score_entry(entry, ['fox']) == 0 + + def test_regex_metacharacters_in_query(self) -> None: + entry = MemoryEntry(key='lang', content='I use c++ daily') + assert _score_entry(entry, ['c++']) == 1 + + def test_empty_words_list(self) -> None: + entry = MemoryEntry(key='k', content='hello') + assert _score_entry(entry, []) == 0 + + def test_underscore_word_boundary(self) -> None: + entry = MemoryEntry(key='user_name', content='text') + assert _score_entry(entry, ['name']) == 1 + + def test_hyphen_word_boundary(self) -> None: + entry = MemoryEntry(key='my-project', content='text') + assert _score_entry(entry, ['project']) == 1 + + def test_partial_word_match(self) -> None: + entry = MemoryEntry(key='k', 
content='alice likes blue') + # 'alice' matches (1), 'zzz' does not (0) = score 1 + assert _score_entry(entry, ['alice', 'zzz']) == 1 + + +# --- _simple_similarity --- + + +class TestSimpleSimilarity: + def test_identical_keys_not_similar(self) -> None: + assert not _simple_similarity('abcdefghij', 'abcdefghij') + + def test_short_keys_not_similar(self) -> None: + assert not _simple_similarity('abc', 'abd') + + def test_similar_long_keys(self) -> None: + # Differ by 2 characters ('fo' vs 'ba') — within the edit-distance threshold + assert _simple_similarity('abcdefghij_fo', 'abcdefghij_ba') + + def test_different_prefix(self) -> None: + assert not _simple_similarity('xxxxxxxxxxfoo', 'yyyyyyyyyyfoo') + + def test_same_prefix_large_edit(self) -> None: + assert not _simple_similarity('abcdefghijklmnop', 'abcdefghijzzzzzz') + + def test_length_diff_too_large(self) -> None: + # Same 10-char prefix but length differs by more than 2 + assert not _simple_similarity('abcdefghij_x', 'abcdefghij_xyzw') + + def test_one_char_diff(self) -> None: + assert _simple_similarity('abcdefghij_x', 'abcdefghij_y') + + def test_edit_distance_exactly_three(self) -> None: + # Just over the threshold -- should NOT be similar + assert not _simple_similarity('abcdefghij_abc', 'abcdefghij_xyz') + + def test_nine_char_keys(self) -> None: + # Just below the 10-char minimum + assert not _simple_similarity('abcdefghi', 'abcdefghj') + + def test_exactly_ten_char_keys_not_similar(self) -> None: + # 10-char keys differing at position 10 do NOT share a 10-char prefix + assert not _simple_similarity('abcdefghij', 'abcdefghik') + + +# --- DictMemoryStore --- + + +class TestDictMemoryStore: + def test_put_and_get(self) -> None: + store = DictMemoryStore() + entry = MemoryEntry(key='greeting', content='hello') + store.put(entry) + assert store.get('greeting') is entry + + def test_get_missing(self) -> None: + store = DictMemoryStore() + assert store.get('nope') is None + + def test_put_overwrites(self) -> 
None: + store = DictMemoryStore() + store.put(MemoryEntry(key='k', content='v1')) + store.put(MemoryEntry(key='k', content='v2')) + result = store.get('k') + assert result is not None + assert result.content == 'v2' + + def test_delete_existing(self) -> None: + store = DictMemoryStore() + store.put(MemoryEntry(key='k', content='v')) + assert store.delete('k') is True + assert store.get('k') is None + + def test_delete_missing(self) -> None: + store = DictMemoryStore() + assert store.delete('nope') is False + + def test_list_all_empty(self) -> None: + store = DictMemoryStore() + assert store.list_all() == [] + + def test_list_all(self) -> None: + store = DictMemoryStore() + store.put(MemoryEntry(key='a', content='alpha')) + store.put(MemoryEntry(key='b', content='beta')) + entries = store.list_all() + assert len(entries) == 2 + assert {e.key for e in entries} == {'a', 'b'} + + def test_list_all_filters_expired(self) -> None: + store = DictMemoryStore() + past = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() + store.put(MemoryEntry(key='alive', content='fresh')) + store.put(MemoryEntry(key='dead', content='stale', expires_at=past)) + entries = store.list_all() + assert len(entries) == 1 + assert entries[0].key == 'alive' + + def test_get_filters_expired(self) -> None: + store = DictMemoryStore() + past = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() + store.put(MemoryEntry(key='dead', content='stale', expires_at=past)) + assert store.get('dead') is None + + def test_list_all_scope_filter(self) -> None: + store = DictMemoryStore() + store.put(MemoryEntry(key='a', content='x', namespace=('project',))) + store.put(MemoryEntry(key='b', content='y', namespace=('global',))) + entries = store.list_all(namespace=('project',)) + assert len(entries) == 1 + assert entries[0].key == 'a' + + def test_list_all_scope_none_returns_all(self) -> None: + store = DictMemoryStore() + store.put(MemoryEntry(key='a', content='x', namespace=('project',))) + 
store.put(MemoryEntry(key='b', content='y', namespace=('global',))) + assert len(store.list_all(namespace=None)) == 2 + + def test_search_by_key(self) -> None: + store = DictMemoryStore() + store.put(MemoryEntry(key='user_name', content='Alice')) + store.put(MemoryEntry(key='color', content='blue')) + results = store.search('user') + assert len(results) == 1 + assert results[0].key == 'user_name' + + def test_search_by_content(self) -> None: + store = DictMemoryStore() + store.put(MemoryEntry(key='k1', content='the quick brown fox')) + store.put(MemoryEntry(key='k2', content='lazy dog')) + results = store.search('fox') + assert len(results) == 1 + assert results[0].key == 'k1' + + def test_search_by_tag(self) -> None: + store = DictMemoryStore() + store.put(MemoryEntry(key='k1', content='x', tags=['important'])) + store.put(MemoryEntry(key='k2', content='y', tags=['trivial'])) + results = store.search('important') + assert len(results) == 1 + assert results[0].key == 'k1' + + def test_search_case_insensitive(self) -> None: + store = DictMemoryStore() + store.put(MemoryEntry(key='K1', content='Hello World')) + results = store.search('hello') + assert len(results) == 1 + + def test_search_no_results(self) -> None: + store = DictMemoryStore() + store.put(MemoryEntry(key='k', content='v')) + assert store.search('zzz') == [] + + def test_search_filters_expired(self) -> None: + store = DictMemoryStore() + past = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() + store.put(MemoryEntry(key='alive', content='hello world')) + store.put(MemoryEntry(key='dead', content='hello world', expires_at=past)) + results = store.search('hello') + assert len(results) == 1 + assert results[0].key == 'alive' + + def test_search_scope_filter(self) -> None: + store = DictMemoryStore() + store.put(MemoryEntry(key='a', content='hello world', namespace=('project',))) + store.put(MemoryEntry(key='b', content='hello world', namespace=('global',))) + results = store.search('hello', 
namespace=('project',)) + assert len(results) == 1 + assert results[0].key == 'a' + + def test_search_relevance_ordering(self) -> None: + store = DictMemoryStore() + # 'hello' appears in key + content = score 2 + store.put(MemoryEntry(key='hello', content='hello there')) + # 'hello' appears only in content = score 1 + store.put(MemoryEntry(key='other', content='hello world')) + results = store.search('hello') + assert len(results) == 2 + assert results[0].key == 'hello' # higher score first + assert results[1].key == 'other' + + def test_search_empty_query(self) -> None: + store = DictMemoryStore() + store.put(MemoryEntry(key='k', content='v')) + assert store.search('') == [] + + def test_list_namespaces(self) -> None: + store = DictMemoryStore() + store.put(MemoryEntry(key='a', content='x', namespace=('users', 'alice'))) + store.put(MemoryEntry(key='b', content='y', namespace=('users', 'bob'))) + store.put(MemoryEntry(key='c', content='z', namespace=('agents', 'planner'))) + assert store.list_namespaces() == [('agents', 'planner'), ('users', 'alice'), ('users', 'bob')] + + def test_list_namespaces_prefix(self) -> None: + store = DictMemoryStore() + store.put(MemoryEntry(key='a', content='x', namespace=('users', 'alice'))) + store.put(MemoryEntry(key='b', content='y', namespace=('agents', 'planner'))) + assert store.list_namespaces(prefix=('users',)) == [('users', 'alice')] + + def test_list_namespaces_max_depth(self) -> None: + store = DictMemoryStore() + store.put(MemoryEntry(key='a', content='x', namespace=('users', 'alice', 'prefs'))) + store.put(MemoryEntry(key='b', content='y', namespace=('users', 'bob', 'prefs'))) + # Truncate to depth 1 → both collapse to ('users',) + assert store.list_namespaces(max_depth=1) == [('users',)] + + def test_list_namespaces_suffix(self) -> None: + store = DictMemoryStore() + store.put(MemoryEntry(key='a', content='x', namespace=('users', 'alice', 'prefs'))) + store.put(MemoryEntry(key='b', content='y', namespace=('agents', 
'planner', 'prefs'))) + assert store.list_namespaces(suffix=('prefs',)) == [ + ('agents', 'planner', 'prefs'), + ('users', 'alice', 'prefs'), + ] + + def test_list_all_namespace_prefix_match(self) -> None: + store = DictMemoryStore() + store.put(MemoryEntry(key='a', content='x', namespace=('users', 'alice'))) + store.put(MemoryEntry(key='b', content='y', namespace=('users', 'bob'))) + store.put(MemoryEntry(key='c', content='z', namespace=('agents',))) + # Prefix ('users',) matches both alice and bob + results = store.list_all(namespace=('users',)) + assert {e.key for e in results} == {'a', 'b'} + + def test_list_all_with_filter(self) -> None: + store = DictMemoryStore() + store.put(MemoryEntry(key='a', content='x', metadata={'priority': 1})) + store.put(MemoryEntry(key='b', content='y', metadata={'priority': 2})) + results = store.list_all(filter={'priority': 1}) + assert len(results) == 1 + assert results[0].key == 'a' + + def test_search_with_filter(self) -> None: + store = DictMemoryStore() + store.put(MemoryEntry(key='a', content='hello world', metadata={'source': 'manual'})) + store.put(MemoryEntry(key='b', content='hello world', metadata={'source': 'auto'})) + results = store.search('hello', filter={'source': 'manual'}) + assert len(results) == 1 + assert results[0].key == 'a' + + def test_search_filter_no_match(self) -> None: + store = DictMemoryStore() + store.put(MemoryEntry(key='a', content='hello world', metadata={'source': 'manual'})) + assert store.search('hello', filter={'source': 'nonexistent'}) == [] + + def test_search_importance_boosts(self) -> None: + store = DictMemoryStore() + # Both match 'hello' once in content; importance differentiates them + store.put(MemoryEntry(key='boring', content='hello there')) + store.put(MemoryEntry(key='vip', content='hello there', importance=2.0)) + results = store.search('hello') + assert results[0].key == 'vip' + assert results[1].key == 'boring' + + def test_search_with_recency_scorer(self) -> None: + store = 
DictMemoryStore() + # Identical keyword scores; recent entry should rank first via recency_scorer. + old = (datetime.now(timezone.utc) - timedelta(days=365)).isoformat() + new = datetime.now(timezone.utc).isoformat() + store.put(MemoryEntry(key='ancient', content='hello there', updated_at=old)) + store.put(MemoryEntry(key='fresh', content='hello there', updated_at=new)) + results = store.search('hello', recency_scorer=exponential_decay(half_life_days=30.0)) + assert results[0].key == 'fresh' + assert results[1].key == 'ancient' + + +class TestExponentialDecay: + def test_fresh_entry_full_weight(self) -> None: + scorer = exponential_decay(half_life_days=30.0, weight=1.0) + entry = MemoryEntry(key='k', content='v') # updated_at = now + # Should be very close to 1.0 (essentially zero seconds elapsed) + assert 0.99 < scorer(entry) <= 1.0 + + def test_one_half_life_old(self) -> None: + scorer = exponential_decay(half_life_days=30.0, weight=1.0) + thirty_days_ago = (datetime.now(timezone.utc) - timedelta(days=30)).isoformat() + entry = MemoryEntry(key='k', content='v', updated_at=thirty_days_ago) + # ~0.5 (within float-precision tolerance) + assert 0.49 < scorer(entry) < 0.51 + + def test_invalid_updated_at_returns_zero(self) -> None: + scorer = exponential_decay() + entry = MemoryEntry(key='k', content='v', updated_at='') + assert scorer(entry) == 0.0 + + def test_weight_multiplier(self) -> None: + scorer = exponential_decay(half_life_days=30.0, weight=0.5) + entry = MemoryEntry(key='k', content='v') # fresh + assert 0.49 < scorer(entry) <= 0.5 + + +# --- FileMemoryStore --- + + +class TestFileMemoryStore: + def test_put_and_get(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileMemoryStore(path) + store.put(MemoryEntry(key='k', content='v')) + assert store.get('k') is not None + assert store.get('k').content == 'v' # type: ignore[union-attr] + + def test_persistence(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store1 = 
FileMemoryStore(path) + store1.put(MemoryEntry(key='k', content='persisted')) + + # New store instance should load from disk + store2 = FileMemoryStore(path) + result = store2.get('k') + assert result is not None + assert result.content == 'persisted' + + def test_delete_saves(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileMemoryStore(path) + store.put(MemoryEntry(key='k', content='v')) + store.delete('k') + + # Reload and verify deletion persisted + store2 = FileMemoryStore(path) + assert store2.get('k') is None + + def test_delete_missing(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileMemoryStore(path) + assert store.delete('nope') is False + + def test_list_all(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileMemoryStore(path) + store.put(MemoryEntry(key='a', content='alpha')) + store.put(MemoryEntry(key='b', content='beta')) + assert len(store.list_all()) == 2 + + def test_search(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileMemoryStore(path) + store.put(MemoryEntry(key='k1', content='hello', tags=['greeting'])) + store.put(MemoryEntry(key='k2', content='world')) + assert len(store.search('greeting')) == 1 + assert len(store.search('hello')) == 1 + + def test_empty_file(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + # File does not exist yet + store = FileMemoryStore(path) + assert store.list_all() == [] + + def test_creates_parent_dirs(self, tmp_path: Path) -> None: + path = tmp_path / 'sub' / 'dir' / 'mem.json' + store = FileMemoryStore(path) + store.put(MemoryEntry(key='k', content='v')) + assert path.exists() + + def test_file_format(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileMemoryStore(path) + store.put(MemoryEntry(key='k', content='v', tags=['t'], created_at='c', updated_at='u')) + raw = json.loads(path.read_text()) + assert raw == { + 'k': { + 'key': 'k', + 'content': 'v', + 'tags': ['t'], + 
'namespace': ['global'], + 'expires_at': None, + 'created_at': 'c', + 'updated_at': 'u', + 'summary': None, + 'metadata': {}, + 'read_only': False, + 'char_limit': None, + 'importance': None, + } + } + + def test_list_all_filters_expired(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileMemoryStore(path) + past = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() + store.put(MemoryEntry(key='alive', content='x')) + store.put(MemoryEntry(key='dead', content='y', expires_at=past)) + assert len(store.list_all()) == 1 + + def test_search_filters_expired(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileMemoryStore(path) + past = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() + store.put(MemoryEntry(key='alive', content='hello world')) + store.put(MemoryEntry(key='dead', content='hello world', expires_at=past)) + assert len(store.search('hello')) == 1 + + def test_list_all_scope(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileMemoryStore(path) + store.put(MemoryEntry(key='a', content='x', namespace=('project',))) + store.put(MemoryEntry(key='b', content='y', namespace=('global',))) + assert len(store.list_all(namespace=('project',))) == 1 + + def test_search_scope(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileMemoryStore(path) + store.put(MemoryEntry(key='a', content='hello world', namespace=('project',))) + store.put(MemoryEntry(key='b', content='hello world', namespace=('global',))) + assert len(store.search('hello', namespace=('project',))) == 1 + + def test_search_empty_query(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileMemoryStore(path) + store.put(MemoryEntry(key='k', content='v')) + assert store.search('') == [] + + def test_load_malformed_json(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + path.write_text('not json at all', encoding='utf-8') + store = FileMemoryStore(path) + assert 
store.list_all() == [] + + def test_load_wrong_structure(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + path.write_text('["a", "b"]', encoding='utf-8') + store = FileMemoryStore(path) + assert store.list_all() == [] + + def test_load_missing_entry_fields(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + path.write_text('{"k": {"not_a_key": "oops"}}', encoding='utf-8') + store = FileMemoryStore(path) + assert store.list_all() == [] + + def test_namespace_persists(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store1 = FileMemoryStore(path) + store1.put(MemoryEntry(key='k', content='v', namespace=('session',))) + store2 = FileMemoryStore(path) + entry = store2.get('k') + assert entry is not None + assert entry.namespace == ('session',) + + def test_nested_namespace_persists(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store1 = FileMemoryStore(path) + store1.put(MemoryEntry(key='k', content='v', namespace=('users', 'alice', 'prefs'))) + store2 = FileMemoryStore(path) + entry = store2.get('k') + assert entry is not None + assert entry.namespace == ('users', 'alice', 'prefs') + + def test_expires_at_persists(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + future = (datetime.now(timezone.utc) + timedelta(hours=1)).isoformat() + store1 = FileMemoryStore(path) + store1.put(MemoryEntry(key='k', content='v', expires_at=future)) + store2 = FileMemoryStore(path) + entry = store2.get('k') + assert entry is not None + assert entry.expires_at == future + + def test_save_drops_expired(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileMemoryStore(path) + past = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() + store.put(MemoryEntry(key='dead', content='stale', expires_at=past)) + store.put(MemoryEntry(key='alive', content='fresh')) + + raw = json.loads(path.read_text(encoding='utf-8')) + assert 'dead' not in raw + assert 'alive' in raw + + def 
test_get_filters_expired(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + store = FileMemoryStore(path) + future = (datetime.now(timezone.utc) + timedelta(seconds=1)).isoformat() + store.put(MemoryEntry(key='soon', content='v', expires_at=future)) + # Manually backdate by mutating the in-memory entry's expires_at + past = (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat() + store._entries['soon'].expires_at = past + assert store.get('soon') is None + + +# --- format_entry --- + + +class TestFormatEntry: + def test_no_tags(self) -> None: + entry = MemoryEntry(key='k', content='hello') + assert format_entry(entry) == '[k] hello' + + def test_with_tags(self) -> None: + entry = MemoryEntry(key='k', content='hello', tags=['a', 'b']) + assert format_entry(entry) == '[k] hello (tags: a, b)' + + def test_with_namespace(self) -> None: + entry = MemoryEntry(key='k', content='hello', namespace=('project',)) + assert format_entry(entry) == '[k] hello (namespace: project)' + + def test_prefer_summary_uses_summary(self) -> None: + entry = MemoryEntry(key='k', content='long content body', summary='short') + assert format_entry(entry, prefer_summary=True) == '[k] short' + + def test_prefer_summary_falls_back_to_content(self) -> None: + entry = MemoryEntry(key='k', content='hello') + assert format_entry(entry, prefer_summary=True) == '[k] hello' + + def test_default_prefers_content(self) -> None: + entry = MemoryEntry(key='k', content='full content', summary='short') + # prefer_summary defaults to False + assert format_entry(entry) == '[k] full content' + + def test_with_nested_namespace(self) -> None: + entry = MemoryEntry(key='k', content='hello', namespace=('users', 'alice')) + assert format_entry(entry) == '[k] hello (namespace: users/alice)' + + def test_global_namespace_omitted(self) -> None: + entry = MemoryEntry(key='k', content='hello', namespace=('global',)) + assert format_entry(entry) == '[k] hello' + + def test_with_expires_at(self) -> None: + 
entry = MemoryEntry(key='k', content='hello', expires_at='2099-01-01T00:00:00+00:00') + assert format_entry(entry) == '[k] hello (expires: 2099-01-01T00:00:00+00:00)' + + def test_all_extras(self) -> None: + entry = MemoryEntry( + key='k', + content='hello', + tags=['t'], + namespace=('project',), + expires_at='2099-01-01T00:00:00+00:00', + ) + assert format_entry(entry) == '[k] hello (tags: t; namespace: project; expires: 2099-01-01T00:00:00+00:00)' + + def test_empty_content(self) -> None: + entry = MemoryEntry(key='k', content='') + assert format_entry(entry) == '[k] ' + + def test_empty_key(self) -> None: + entry = MemoryEntry(key='', content='hello') + assert format_entry(entry) == '[] hello' + + +# --- Memory capability --- + + +class TestMemoryCapability: + def test_serialization_name(self) -> None: + assert Memory.get_serialization_name() == 'Memory' + + def test_from_spec_default(self) -> None: + cap = Memory.from_spec() + assert isinstance(cap.store, DictMemoryStore) + + def test_from_spec_file(self, tmp_path: Path) -> None: + path = tmp_path / 'mem.json' + cap = Memory.from_spec(backend='file', path=str(path)) + assert isinstance(cap.store, FileMemoryStore) + + def test_from_spec_unknown_backend(self) -> None: + import pytest + + with pytest.raises(ValueError, match='Unknown memory backend'): + Memory.from_spec(backend='redis') + + def test_from_spec_explicit_memory_backend(self) -> None: + cap = Memory.from_spec(backend='memory') + assert isinstance(cap.store, DictMemoryStore) + + def test_from_spec_with_options(self, tmp_path: Path) -> None: + cap = Memory.from_spec( + backend='file', + path=str(tmp_path / 'mem.json'), + inject_memories_in_instructions=False, + max_instructions_memories=10, + ) + assert isinstance(cap.store, FileMemoryStore) + assert cap.inject_memories_in_instructions is False + assert cap.max_instructions_memories == 10 + + def test_default_store(self) -> None: + cap: Memory[None] = Memory() + assert isinstance(cap.store, 
DictMemoryStore) + + def test_get_toolset_returns_function_toolset(self) -> None: + cap: Memory[None] = Memory() + toolset = cap.get_toolset() + assert isinstance(toolset, FunctionToolset) + + def test_toolset_has_expected_tools(self) -> None: + cap: Memory[None] = Memory() + toolset = cap.get_toolset() + assert isinstance(toolset, FunctionToolset) + tool_names = set(toolset.tools.keys()) + assert tool_names == {'save_memory', 'recall_memory', 'search_memories', 'list_memories', 'delete_memory'} + + def test_tool_descriptions_override(self) -> None: + custom = 'CUSTOM: Save aggressively. Tiny facts count.' + cap: Memory[None] = Memory(tool_descriptions={'save_memory': custom}) + toolset = cap.get_toolset() + assert isinstance(toolset, FunctionToolset) + assert toolset.tools['save_memory'].description == custom + # Other tools fall back to docstring (not the override) + recall_desc = toolset.tools['recall_memory'].description + assert recall_desc is not None and 'Recall' in recall_desc + + +# --- Tool functions (via closure) --- + + +class TestMemoryTools: + """Test the tool functions exposed by the Memory capability.""" + + @staticmethod + def _get_tools(store: DictMemoryStore | None = None) -> dict[str, Any]: + cap: Memory[None] = Memory(store=store or DictMemoryStore()) + toolset = cap.get_toolset() + assert isinstance(toolset, FunctionToolset) + return {name: tool.function for name, tool in toolset.tools.items()} + + def test_save_and_recall(self) -> None: + store = DictMemoryStore() + tools = self._get_tools(store) + result = tools['save_memory']('greeting', 'hello world') + assert result == 'Memory saved: greeting' + + recalled = tools['recall_memory']('greeting') + assert '[greeting] hello world' in recalled + + def test_recall_missing(self) -> None: + tools = self._get_tools() + assert 'No memory found' in tools['recall_memory']('nope') + + def test_recall_expired(self) -> None: + store = DictMemoryStore() + past = (datetime.now(timezone.utc) - 
timedelta(hours=1)).isoformat() + store.put(MemoryEntry(key='old', content='stale', expires_at=past)) + tools = self._get_tools(store) + assert 'No memory found' in tools['recall_memory']('old') + + def test_save_updates_existing(self) -> None: + store = DictMemoryStore() + tools = self._get_tools(store) + tools['save_memory']('k', 'v1') + original = store.get('k') + assert original is not None + original_created = original.created_at + + tools['save_memory']('k', 'v2') + updated = store.get('k') + assert updated is not None + assert updated.content == 'v2' + # created_at should be preserved + assert updated.created_at == original_created + + def test_save_with_tags(self) -> None: + store = DictMemoryStore() + tools = self._get_tools(store) + tools['save_memory']('k', 'v', ['tag1', 'tag2']) + entry = store.get('k') + assert entry is not None + assert entry.tags == ['tag1', 'tag2'] + + def test_save_with_namespace(self) -> None: + store = DictMemoryStore() + tools = self._get_tools(store) + tools['save_memory']('k', 'v', None, ['project']) + entry = store.get('k') + assert entry is not None + assert entry.namespace == ('project',) + + def test_save_with_nested_namespace(self) -> None: + store = DictMemoryStore() + tools = self._get_tools(store) + tools['save_memory']('k', 'v', None, ['users', 'alice']) + entry = store.get('k') + assert entry is not None + assert entry.namespace == ('users', 'alice') + + def test_save_with_ttl(self) -> None: + store = DictMemoryStore() + tools = self._get_tools(store) + tools['save_memory']('k', 'v', None, ['global'], 60) + entry = store.get('k') + assert entry is not None + assert entry.expires_at is not None + expires = datetime.fromisoformat(entry.expires_at) + # Should expire roughly 60 minutes from now + assert expires > datetime.now(timezone.utc) + timedelta(minutes=59) + assert expires < datetime.now(timezone.utc) + timedelta(minutes=61) + + def test_search(self) -> None: + store = DictMemoryStore() + tools = 
self._get_tools(store) + tools['save_memory']('user_name', 'Alice') + tools['save_memory']('color', 'blue') + + result = tools['search_memories']('Alice') + assert 'Alice' in result + assert 'blue' not in result + + def test_search_no_results(self) -> None: + tools = self._get_tools() + assert 'No memories found' in tools['search_memories']('zzz') + + def test_search_with_scope(self) -> None: + store = DictMemoryStore() + tools = self._get_tools(store) + tools['save_memory']('a', 'hello world', None, ['project']) + tools['save_memory']('b', 'hello world', None, ['global']) + result = tools['search_memories']('hello', ['project']) + assert '[a]' in result + assert '[b]' not in result + + def test_list_empty(self) -> None: + tools = self._get_tools() + assert tools['list_memories']() == 'No memories stored.' + + def test_list_with_entries(self) -> None: + store = DictMemoryStore() + tools = self._get_tools(store) + tools['save_memory']('a', 'alpha') + tools['save_memory']('b', 'beta') + result = tools['list_memories']() + assert '[a] alpha' in result + assert '[b] beta' in result + + def test_list_with_scope(self) -> None: + store = DictMemoryStore() + tools = self._get_tools(store) + tools['save_memory']('a', 'alpha', None, ['project']) + tools['save_memory']('b', 'beta', None, ['global']) + result = tools['list_memories'](['project']) + assert '[a] alpha' in result + assert '[b]' not in result + + def test_delete_existing(self) -> None: + store = DictMemoryStore() + tools = self._get_tools(store) + tools['save_memory']('k', 'v') + assert tools['delete_memory']('k') == 'Memory deleted: k' + assert store.get('k') is None + + def test_delete_missing(self) -> None: + tools = self._get_tools() + assert 'No memory found' in tools['delete_memory']('nope') + + def test_save_with_ttl_zero(self) -> None: + store = DictMemoryStore() + tools = self._get_tools(store) + tools['save_memory']('k', 'v', None, ['global'], 0) + # TTL=0 expires immediately; get() filters it out + 
assert store.get('k') is None + # recall_memory should likewise report no memory + assert 'No memory found' in tools['recall_memory']('k') + + def test_save_with_summary_and_importance(self) -> None: + store = DictMemoryStore() + tools = self._get_tools(store) + tools['save_memory']('k', 'long content here', None, ['global'], None, 'short', 0.9) + entry = store.get('k') + assert entry is not None + assert entry.summary == 'short' + assert entry.importance == 0.9 + + def test_save_refuses_read_only(self) -> None: + store = DictMemoryStore() + store.put(MemoryEntry(key='persona', content='locked', read_only=True)) + tools = self._get_tools(store) + result = tools['save_memory']('persona', 'overwrite attempt') + assert 'read-only' in result.lower() + # Original content preserved + entry = store.get('persona') + assert entry is not None + assert entry.content == 'locked' + + def test_delete_refuses_read_only(self) -> None: + store = DictMemoryStore() + store.put(MemoryEntry(key='persona', content='locked', read_only=True)) + tools = self._get_tools(store) + result = tools['delete_memory']('persona') + assert 'read-only' in result.lower() + assert store.get('persona') is not None + + +# --- Dedup warning --- + + +class TestDedupWarning: + def test_similar_key_logs_warning(self, caplog: Any) -> None: + store = DictMemoryStore() + tools = TestMemoryTools._get_tools(store) + tools['save_memory']('abcdefghij_x', 'first value') + with caplog.at_level(logging.WARNING, logger='pydantic_ai_harness.memory'): + tools['save_memory']('abcdefghij_y', 'second value') + assert any('possible duplicate' in record.message.lower() for record in caplog.records) + + def test_different_keys_no_warning(self, caplog: Any) -> None: + store = DictMemoryStore() + tools = TestMemoryTools._get_tools(store) + tools['save_memory']('first_key_long', 'first value') + with caplog.at_level(logging.WARNING, logger='pydantic_ai_harness.memory'): + tools['save_memory']('other_key_long', 'second value') + 
assert not any('possible duplicate' in record.message.lower() for record in caplog.records) + + def test_short_keys_no_warning(self, caplog: Any) -> None: + store = DictMemoryStore() + tools = TestMemoryTools._get_tools(store) + tools['save_memory']('abc', 'first value') + with caplog.at_level(logging.WARNING, logger='pydantic_ai_harness.memory'): + tools['save_memory']('abd', 'second value') + assert not any('possible duplicate' in record.message.lower() for record in caplog.records) + + +# --- Instructions --- + + +class TestMemoryInstructions: + @staticmethod + def _make_ctx() -> RunContext[None]: + from unittest.mock import MagicMock + + return RunContext( + deps=None, + model=MagicMock(), + usage=RunUsage(), + ) + + def test_get_instructions_is_callable(self) -> None: + cap: Memory[None] = Memory() + assert callable(cap.get_instructions()) + + def test_instructions_with_no_memories(self) -> None: + cap: Memory[None] = Memory() + text = cap.build_instructions(self._make_ctx()) + assert 'persistent memory system' in text + assert 'Currently stored memories' not in text + + def test_instructions_with_memories(self) -> None: + store = DictMemoryStore() + store.put(MemoryEntry(key='user', content='Alice')) + cap: Memory[None] = Memory(store=store) + text = cap.build_instructions(self._make_ctx()) + assert 'Currently stored memories' in text + assert '[user] Alice' in text + + def test_instructions_respects_max(self) -> None: + store = DictMemoryStore() + for i in range(25): + store.put(MemoryEntry(key=f'k{i}', content=f'v{i}')) + cap: Memory[None] = Memory(store=store, max_instructions_memories=5) + text = cap.build_instructions(self._make_ctx()) + assert '... 
and 20 more' in text + + def test_instructions_disabled(self) -> None: + store = DictMemoryStore() + store.put(MemoryEntry(key='k', content='v')) + cap: Memory[None] = Memory(store=store, inject_memories_in_instructions=False) + text = cap.build_instructions(self._make_ctx()) + assert 'Currently stored memories' not in text + + def test_instructions_exact_max_no_overflow(self) -> None: + store = DictMemoryStore() + for i in range(5): + store.put(MemoryEntry(key=f'k{i}', content=f'v{i}')) + cap: Memory[None] = Memory(store=store, max_instructions_memories=5) + text = cap.build_instructions(self._make_ctx()) + assert '... and' not in text + assert '[k0]' in text + assert '[k4]' in text + + def test_instructions_use_summary(self) -> None: + store = DictMemoryStore() + store.put( + MemoryEntry(key='user', content='Alice is a 30-year-old data scientist', summary='Alice, data scientist') + ) + cap: Memory[None] = Memory(store=store) + text = cap.build_instructions(self._make_ctx()) + assert '[user] Alice, data scientist' in text + assert '30-year-old' not in text + + def test_instructions_falls_back_to_content_without_summary(self) -> None: + store = DictMemoryStore() + store.put(MemoryEntry(key='user', content='Alice is a 30-year-old data scientist')) + cap: Memory[None] = Memory(store=store) + text = cap.build_instructions(self._make_ctx()) + assert 'Alice is a 30-year-old data scientist' in text + + def test_instructions_byte_budget_truncates(self) -> None: + store = DictMemoryStore() + for i in range(10): + store.put(MemoryEntry(key=f'k{i}', content=f'content for entry {i}')) + # Tight budget: only ~2 short lines fit + cap: Memory[None] = Memory(store=store, byte_budget=80) + text = cap.build_instructions(self._make_ctx()) + assert '[k0]' in text + # k9 (last) almost certainly excluded by byte budget + assert '[k9]' not in text + assert 'more (use list_memories' in text + + def test_instructions_pinned_always_injected(self) -> None: + store = DictMemoryStore() + # 5 
normal entries + 1 pinned, with max=2: pinned is always shown + for i in range(5): + store.put(MemoryEntry(key=f'normal{i}', content=f'v{i}')) + store.put(MemoryEntry(key='persona', content='I am a helpful assistant', read_only=True)) + cap: Memory[None] = Memory(store=store, max_instructions_memories=2) + text = cap.build_instructions(self._make_ctx()) + assert '[persona] I am a helpful assistant' in text + # Two normal entries + the pinned one = 3 displayed + normal_displayed = sum(1 for i in range(5) if f'[normal{i}]' in text) + assert normal_displayed == 2 + + def test_instructions_pinned_bypasses_byte_budget(self) -> None: + store = DictMemoryStore() + # Big pinned entry that would overflow a tight budget + store.put(MemoryEntry(key='persona', content='x' * 500, read_only=True)) + store.put(MemoryEntry(key='normal', content='small')) + cap: Memory[None] = Memory(store=store, byte_budget=50) + text = cap.build_instructions(self._make_ctx()) + assert '[persona]' in text + # Pinned takes the budget; normal gets dropped + assert '[normal]' not in text + + def test_instructions_pinned_listed_first(self) -> None: + store = DictMemoryStore() + store.put(MemoryEntry(key='zzz_normal', content='v')) + store.put(MemoryEntry(key='aaa_pinned', content='p', read_only=True)) + cap: Memory[None] = Memory(store=store) + text = cap.build_instructions(self._make_ctx()) + # Pinned line appears before normal line + assert text.index('[aaa_pinned]') < text.index('[zzz_normal]') + + @staticmethod + def _ctx_with_save(key: str, content: str) -> RunContext[None]: + from unittest.mock import MagicMock + + from pydantic_ai.messages import ModelMessage, ModelResponse, ToolCallPart + + msgs: list[ModelMessage] = [ + ModelResponse(parts=[ToolCallPart(tool_name='save_memory', args={'key': key, 'content': content})]), + ] + return RunContext(deps=None, model=MagicMock(), usage=RunUsage(), messages=msgs) + + def test_instructions_dedup_suppresses_recently_saved(self) -> None: + store = 
DictMemoryStore() + store.put(MemoryEntry(key='user_name', content='Alice')) + cap: Memory[None] = Memory(store=store) + text = cap.build_instructions(self._ctx_with_save('user_name', 'Alice')) + # LLM has already seen 'Alice' via the save_memory call in tool history + assert '[user_name]' not in text + + def test_instructions_dedup_injects_when_content_diverges(self) -> None: + # Saved 'Alice' but store now has 'Alice (UPDATED)' (e.g., another agent updated it) + store = DictMemoryStore() + store.put(MemoryEntry(key='user_name', content='Alice (UPDATED)')) + cap: Memory[None] = Memory(store=store) + text = cap.build_instructions(self._ctx_with_save('user_name', 'Alice')) + # Inject because current content differs from what the LLM saw saved + assert '[user_name] Alice (UPDATED)' in text + + def test_instructions_dedup_disabled_via_flag(self) -> None: + store = DictMemoryStore() + store.put(MemoryEntry(key='user_name', content='Alice')) + cap: Memory[None] = Memory(store=store, dedup_recent_saves=False) + text = cap.build_instructions(self._ctx_with_save('user_name', 'Alice')) + # Always inject when dedup is off + assert '[user_name] Alice' in text + + def test_instructions_dedup_does_not_affect_pinned(self) -> None: + store = DictMemoryStore() + store.put(MemoryEntry(key='persona', content='I am helpful', read_only=True)) + cap: Memory[None] = Memory(store=store) + text = cap.build_instructions(self._ctx_with_save('persona', 'I am helpful')) + # Pinned entries always inject regardless of dedup + assert '[persona] I am helpful' in text + + def test_instructions_dedup_uses_most_recent_save(self) -> None: + # Two saves of the same key; only the most recent one's content matters for dedup. 
+ from unittest.mock import MagicMock + + from pydantic_ai.messages import ModelMessage, ModelResponse, ToolCallPart + + store = DictMemoryStore() + store.put(MemoryEntry(key='user_pref', content='loves green')) + msgs: list[ModelMessage] = [ + ModelResponse( + parts=[ToolCallPart(tool_name='save_memory', args={'key': 'user_pref', 'content': 'loves blue'})] + ), + ModelResponse( + parts=[ToolCallPart(tool_name='save_memory', args={'key': 'user_pref', 'content': 'loves green'})] + ), + ] + ctx: RunContext[None] = RunContext(deps=None, model=MagicMock(), usage=RunUsage(), messages=msgs) + cap: Memory[None] = Memory(store=store) + text = cap.build_instructions(ctx) + # Most recent save matches current content → suppress + assert '[user_pref]' not in text + + +# --- MemoryStore protocol --- + + +class TestMemoryStoreProtocol: + def test_in_memory_store_satisfies_protocol(self) -> None: + assert isinstance(DictMemoryStore(), MemoryStore) + + def test_file_store_satisfies_protocol(self, tmp_path: Path) -> None: + assert isinstance(FileMemoryStore(tmp_path / 'mem.json'), MemoryStore) + + +# --- AbstractCapability conformance --- + + +class TestAbstractCapabilityConformance: + def test_is_abstract_capability_subclass(self) -> None: + from pydantic_ai.capabilities.abstract import AbstractCapability + + assert issubclass(Memory, AbstractCapability) + + def test_instance_is_abstract_capability(self) -> None: + from pydantic_ai.capabilities.abstract import AbstractCapability + + assert isinstance(Memory(), AbstractCapability)