diff --git a/cli/Dockerfile b/cli/Dockerfile
new file mode 100644
index 0000000000..9c261f55e8
--- /dev/null
+++ b/cli/Dockerfile
@@ -0,0 +1,17 @@
+FROM python:3.12-slim-bookworm
+
+# Package index used while building. Defaults to the Tsinghua (TUNA) mirror;
+# override it where that mirror is slow or unreachable, e.g.
+#   docker compose build --build-arg PIP_INDEX_URL=https://pypi.org/simple
+ARG PIP_INDEX_URL=https://pypi.tuna.tsinghua.edu.cn/simple
+
+# Install the harness package (non-editable) from the repository checkout.
+COPY backend/packages/harness /harness
+RUN pip install --no-cache-dir --index-url "${PIP_INDEX_URL}" /harness
+
+# docker-compose bind-mounts the repository at /deer-flow; start there so
+# `docker compose exec app python cli/cli.py` works without an explicit `cd`.
+WORKDIR /deer-flow
+
+# Keep the container idle; the CLI is started interactively via `exec`.
+CMD ["tail", "-f", "/dev/null"]
diff --git a/cli/README_zh.md b/cli/README_zh.md
new file mode 100644
index 0000000000..f41f4aeb30
--- /dev/null
+++ b/cli/README_zh.md
@@ -0,0 +1,348 @@
+# DeerFlow CLI
+
+DeerFlow CLI 是 DeerFlow AI 代理系统的命令行接口,提供完整的会话管理、持久化存储、流式响应和工具集成能力。采用**每个会话独立SQLite数据库**的架构设计,旨在解决全局锁竞争和状态污染问题。
+
+## 核心特性
+
+- **会话隔离**:每个会话拥有独立的SQLite数据库,无全局锁
+- **检查点保留**:所有执行步骤持久化,支持行为审计
+- **异步持久化**:后台线程处理文件写入,不阻塞主事件循环
+- **多会话管理**:创建、切换、删除、归档、恢复会话
+- **会话导出**:导出为 Markdown 格式
+- **会话搜索**:在所有会话中搜索关键词
+- **文件管理**:支持文件上传、列出和删除
+- **模型与技能**:动态切换模型,启用/禁用技能
+- **运行模式**:计划模式、子代理模式开关
+- **诊断系统**:工具调用分析、状态监控、递归限制设置
+- **错误处理**:故障排查指引
+
+## 项目结构
+
+```
+cli/
+├── cli.py # 命令行接口
+├── engine.py # 核心引擎实现
+├── session_store.py # 会话存储实现
+├── Dockerfile # Docker构建文件
+├── docker-compose.yaml # Docker Compose配置
+└── __init__.py # 模块初始化
+```
+
+运行时数据目录:
+
+```
+.deer-flow/
+└── deerflow_sessions/
+    ├── archive/ # 归档会话目录
+    ├── .json # 会话元数据文件
+    ├── _checkpoints.db # 会话数据库
+    └── / # 导出文件目录
+        └── export_.md # 导出的Markdown文件
+```
+
+## 快速开始
+
+### 前提条件
+
+- Python 3.12+
+- Docker 和 Docker Compose (可选)
+
+### 本地运行
+
+首先在项目根目录配置环境变量(包含模型、技能、MCP工具等配置):
+
+```bash
+cd deer-flow
+make config
+```
+
+**配置提示:**
+- **模型配置**:参考 config.example.yaml,配置在 config.yaml 中
+- **技能配置**:在 `skills/` 目录下添加或修改技能配置文件
+- **MCP工具配置**:参考 extensions_config.example.json,配置在 extensions_config.json 中
+
+然后安装 harness 包(开发模式)并运行 CLI:
+
+```bash
+cd backend/packages/harness
+pip install -e .
+ +cd ../../../cli +python cli.py +``` + +### Docker 运行 + +首先在宿主机配置环境变量(包含模型、技能、MCP工具等配置): + +```bash +cd deer-flow +make config +``` + +**配置提示:** +- **模型配置**:参考 config.example.yaml,配置在 config.yaml 中 +- **技能配置**:在 `skills/` 目录下添加或修改技能配置文件 +- **MCP工具配置**:参考 extensions_config.example.json,配置在 extensions_config.json 中 + +然后构建并运行容器: + +```bash +cd cli +docker compose build +docker compose up -d +docker compose exec app bash -c "cd /deer-flow && python cli/cli.py" +``` + +## 环境配置 + +### 离线环境 + +如遇 CLI 卡在 tiktoken 加载,可预缓存编码文件: + +```bash +# 下载、计算 hash、重命名 +mkdir -p ~/.tiktoken_cache +curl -L https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken \ + -o ~/.tiktoken_cache/cl100k_base.tiktoken +hash=$(sha1sum ~/.tiktoken_cache/cl100k_base.tiktoken | cut -d' ' -f1) +mv ~/.tiktoken_cache/cl100k_base.tiktoken ~/.tiktoken_cache/${hash}.tiktoken +``` + +Docker 挂载(在 docker-compose.yaml 添加): +```yaml +volumes: + - ${HOME}/.tiktoken_cache:/root/.tiktoken_cache:ro +environment: + - TIKTOKEN_CACHE_DIR=/root/.tiktoken_cache +``` + +验证离线可用: +```bash +# 本地 +TIKTOKEN_CACHE_DIR=~/.tiktoken_cache python -c " +import tiktoken +tiktoken.get_encoding('cl100k_base') +print('✓ 离线缓存可用') +" + +# Docker +docker compose exec app python -c " +import tiktoken +tiktoken.get_encoding('cl100k_base') +print('✓ 离线缓存可用') +" +``` + +## 使用说明 + +### 基本交互 + +启动后直接输入问题即可与AI代理对话: + +``` +====================================================================== +DeerFlow Production Engine - 本地测试模式 +====================================================================== +Type !help to see all available commands | 输入 !help 查看所有可用命令 +Type !multi to enter multi-line input mode | 输入 !multi 进入多行输入模式 +====================================================================== + +[abcdef12] You: 你好 +AI: 你好!我是DeerFlow AI助手,有什么可以帮助你的? 
+
+[Metrics] Tokens: 42 | Tool Calls: 0
+```
+
+### 命令列表
+
+#### 会话管理
+
+| 命令 | 说明 |
+|------|------|
+| `!new [id] [title]` | 创建新会话 |
+| `!switch <id>` | 切换到指定会话 |
+| `!delete session <id>` | 删除指定会话 |
+| `!rename <title>` | 重命名当前会话 |
+| `!archive <id>` | 归档指定会话 |
+| `!archives` | 列出所有归档会话 |
+| `!restore <id>` | 从归档恢复会话 |
+| `!sessions` | 列出所有活动会话 |
+
+#### 调试诊断
+
+| 命令 | 说明 |
+|------|------|
+| `!steps` | 查看当前会话的步骤列表(去重) |
+| `!steps_all` | 查看全部检查点(包含无新内容的检查点) |
+| `!diagnose` | 分析工具调用模式,检测潜在循环 |
+| `!status` | 显示当前会话状态和运行时配置 |
+| `!search <keyword>` | 在所有会话中搜索关键词 |
+
+#### 文件管理
+
+| 命令 | 说明 |
+|------|------|
+| `!upload <path>` | 上传文件到当前会话 |
+| `!files` | 列出当前会话的所有上传文件 |
+| `!delete <filename>` | 删除指定的上传文件 |
+
+#### 模型与技能
+
+| 命令 | 说明 |
+|------|------|
+| `!models` | 列出所有可用模型 |
+| `!use <model>` | 切换到指定模型 |
+| `!skills` | 列出所有可用技能 |
+| `!enable <skill>` | 启用指定技能 |
+| `!disable <skill>` | 禁用指定技能 |
+
+#### 运行模式
+
+| 命令 | 说明 |
+|------|------|
+| `!plan on/off` | 开启/关闭计划模式 |
+| `!subagent on/off` | 开启/关闭子代理委托 |
+| `!recursion_limit <N>` | 设置递归限制(默认:1000) |
+
+#### 记忆系统
+
+| 命令 | 说明 |
+|------|------|
+| `!memory` | 查看当前会话的记忆 |
+| `!clear` | 清空当前会话的记忆 |
+
+#### 其他
+
+| 命令 | 说明 |
+|------|------|
+| `!export` | 导出当前会话为Markdown |
+| `!export_all` | 导出全部检查点为Markdown |
+| `!multi` | 进入多行输入模式 |
+| `!help` | 显示帮助信息 |
+| `!exit` | 退出系统 |
+
+### 多行输入模式
+
+```
+[abcdef12] You: !multi
+
+[abcdef12] Multi-line Input Mode | 多行输入模式
+Enter !end to finish multi-line input | 输入 !end 结束多行输入
+
+这是第一行
+这是第二行
+这是第三行
+!end
+
+AI: 我收到了你的多行输入,内容是:
+这是第一行
+这是第二行
+这是第三行
+```
+
+### 诊断功能
+
+#### 工具调用诊断
+
+```
+[abcdef12] You: !diagnose
+
+[Tool Call Diagnostics | 工具调用诊断]
+  Session: abcdef12
+  Total checkpoints: 42
+
+[Tool Call Frequency Comparison | 工具调用频率对比]
+  Tool Name | Unique | Raw (with dup) | Ratio
+  ----------------------------------------+----------+-----------------+--------
+  Read | 15 | 120 | 8.0x
+  Bash | 8 | 64 | 8.0x
+  Grep | 3 | 24 | 8.0x
+
+[Checkpoint Density Analysis | 检查点密度分析]
+
Unique tool calls: 26 + Raw occurrences across all checkpoints: 208 + Average duplications per unique call: 8.0x + ⚠️ High duplication - each tool call appears in many checkpoints + This is normal for long-running sessions with subagents + +[Potential Loop Detection | 潜在循环检测] + ⚠️ Read: 5 consecutive calls (potential loop) +``` + +#### 会话状态 + +``` +[abcdef12] You: !status + +[Session Status | 会话状态] + Session ID: abcdef12 + +[Runtime Settings | 运行时配置] + Model: claude-opus-4-7 + Subagent: ✓ Enabled + Plan Mode: ✗ Disabled + Thinking: ✓ Enabled + +[Session Metrics | 会话指标] + Checkpoints: 42 + Recursion Limit: 1000 + ⚠️ Approaching recursion limit (42/1000) +``` + +### 错误处理 + +当检测到Error时,系统会自动显示: + +``` +[Critical Error | 严重错误] Error ... + +[Traceback | 堆栈跟踪] +... + +[Session Status at Error | 错误发生时的会话状态] + Subagent: ✓ Enabled + Plan Mode: ✗ Disabled + +[Troubleshooting | 故障排除] + 1. 使用 !status 查看完整会话状态 + 2. 使用 !diagnose 分析工具调用模式 + 3. 使用 !steps_all 查看已保存的检查点 + 4. 使用 !export_all 导出完整检查点历史 + 5. Subagent 当前已启用 - 尝试关闭: !subagent off +``` + +### 检查点警告 + +当检查点数量接近递归限制时: + +``` +[WARNING] Checkpoints: 850/1000 - Getting close to limit + +⚠️ [CRITICAL] Checkpoints: 920/1000 - Approaching recursion limit! + Subagent is enabled - consider disabling with: !subagent off +``` + +## 架构设计 + +### 核心组件 + +1. **cli.py**:命令行接口,处理用户输入和命令解析 +2. **engine.py**:核心引擎,管理会话生命周期和代理交互 +3. **session_store.py**:会话存储,异步持久化会话元数据 + +### 关键设计 + +- **单例模式**:整个应用只有一个引擎实例 +- **每会话独立数据库**:每个会话对应独立的SQLite文件,消除锁竞争 +- **异步写入**:后台线程处理文件写入,不阻塞主线程 +- **优雅关闭**:确保所有资源被正确释放 +- **诊断能力**:工具调用模式分析,辅助问题排查 + +## 许可证 + +MIT License + +--- + +**DeerFlow CLI** - DeerFlow AI 代理命令行工具 diff --git a/cli/__init__.py b/cli/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/cli/cli.py b/cli/cli.py new file mode 100644 index 0000000000..dda1449954 --- /dev/null +++ b/cli/cli.py @@ -0,0 +1,463 @@ +""" +DeerFlow Production Engine CLI + +Command-line interface for the DeerFlow Production Engine. 
+Provides interactive chat, session management, and configuration commands.
+
+Author: heart-scalpel
+License: MIT
+"""
+
+import sys
+import traceback
+from engine import DeerFlowProductionEngine
+
+
+def safe_input(prompt):
+    """
+    Read one line of user input, preferring UTF-8 for stdin/stdout.
+
+    Encoding errors are reported and the prompt is shown again.
+    End-of-input (Ctrl-D, an exhausted pipe, a closed stdin) is translated
+    into the ``!exit`` command: returning an empty string here made the
+    caller's "skip empty input" branch re-prompt forever once stdin had
+    run dry.
+
+    Args:
+        prompt: The input prompt to display.
+
+    Returns:
+        str: The input line without its trailing newline, or ``"!exit"``
+        when stdin reaches end-of-file.
+    """
+    try:
+        sys.stdin.reconfigure(encoding='utf-8')
+        sys.stdout.reconfigure(encoding='utf-8')
+    except (AttributeError, ValueError):
+        # AttributeError: the stream has no reconfigure() (not a
+        # TextIOWrapper, or None).
+        # ValueError: covers io.UnsupportedOperation, which reconfigure()
+        # raises when the encoding is set after data has already been read
+        # from the stream (this function runs once per prompt).
+        pass
+    while True:
+        try:
+            return input(prompt).rstrip('\n')
+        except UnicodeDecodeError:
+            print("\n[Error] Input encoding error. Please use UTF-8. | 输入编码错误,请使用UTF-8\n")
+        except EOFError:
+            # No more input will ever arrive; ask the main loop to stop.
+            return "!exit"
+
+
+def multi_line_input(prompt):
+    """
+    Read multi-line input from the user.
+
+    Args:
+        prompt: The prompt to display before entering multi-line mode.
+
+    Returns:
+        str: The combined multi-line input.
+    """
+    print(prompt)
+    print("Enter !end to finish multi-line input | 输入 !end 结束多行输入\n")
+    lines = []
+    while True:
+        try:
+            line = input()
+            if line.strip().lower() == '!end':
+                break
+            lines.append(line)
+        except UnicodeDecodeError:
+            print("\n[Error] Input encoding error. Please use UTF-8. 
| 输入编码错误,请使用UTF-8\n")
+        except EOFError:
+            break
+    return '\n'.join(lines)
+
+
+def main():
+    """Run the interactive DeerFlow CLI loop.
+
+    Creates the engine, prints the banner, then reads one input per
+    iteration and dispatches ``!``-prefixed commands; any other input is
+    sent to the agent as a chat message.  ``!exit`` shuts the engine down
+    before leaving the loop.
+    """
+    engine = DeerFlowProductionEngine()
+
+    print("=" * 70)
+    print("DeerFlow Production Engine - Local Test Mode")
+    print("DeerFlow 生产引擎 - 本地测试模式")
+    print("=" * 70)
+    print("Type !help to see all available commands | 输入 !help 查看所有可用命令")
+    print("Type !multi to enter multi-line input mode | 输入 !multi 进入多行输入模式")
+    print("=" * 70)
+    print()
+
+    while True:
+        try:
+            # A session may have been deleted/archived by the previous command.
+            if engine.current_session_id is None:
+                engine.create_session()
+
+            # Show only first 8 characters of session ID for brevity
+            prompt = f"[{engine.current_session_id[:8]}] You: "
+            user_input = safe_input(prompt).strip()
+
+            if not user_input:
+                continue
+
+            # Multi-line mode replaces user_input with the collected text,
+            # which then flows through the normal command/chat dispatch below.
+            if user_input.lower() == '!multi':
+                user_input = multi_line_input(f"\n[{engine.current_session_id[:8]}] Multi-line Input Mode | 多行输入模式")
+                if not user_input.strip():
+                    print("\n[Info] Empty input ignored | 空输入已忽略\n")
+                    continue
+                print()
+
+            if user_input.lower() == "!exit":
+                # Release checkpointers and the session store before leaving,
+                # exactly as the Ctrl-C path does; a bare `break` skipped the
+                # graceful shutdown.
+                engine.shutdown()
+                break
+
+            if user_input.lower() == "!help":
+                print("\n[Available Commands | 可用命令]")
+                print(" Session Management | 会话管理:")
+                print(" !new [id] [title] Create new session | 创建新会话")
+                print(" !switch <id> Switch to session | 切换会话")
+                print(" !delete session <id> Delete session | 删除会话")
+                print(" !rename <title> Rename current session | 重命名当前会话")
+                print(" !archive <id> Archive session | 归档会话")
+                print(" !archives List archived sessions | 查看归档会话")
+                print(" !restore <id> Restore session from archive | 从归档恢复会话")
+                print(" !sessions List all sessions | 列出所有会话")
+                print(" !export Export current session to Markdown | 导出当前会话为Markdown")
+                print(" !export_all Export all checkpoints to Markdown | 导出全部检查点为Markdown")
+                print(" !search <keyword> Search all sessions | 搜索所有会话")
+                print(" Debugging | 诊断流程:")
+                print(" !steps Show current session steps | 查看当前会话步骤")
+                # !steps_all lists every checkpoint (flagging those without new
+                # content); it is not de-duplicated, so the help must not say so.
+                print(" !steps_all Show all checkpoints (incl. those without new content) | 查看全部检查点(包含无新内容的检查点)")
+
print(" !diagnose Analyze tool call patterns for loops | 分析工具调用模式检测循环") + # print(" !back <N> Go back to step N | 回退到第N步") + # print(" !back_cp <N> Go back to checkpoint N (from !steps_all) | 回退到第N个检查点") + print(" File Management | 文件管理:") + print(" !upload <path> Upload file | 上传文件") + print(" !files List uploaded files | 列出上传文件") + print(" !delete <file> Delete uploaded file | 删除上传文件") + print(" Models & Skills | 模型与技能:") + print(" !models List available models | 列出可用模型") + print(" !use <model> Switch model | 切换模型") + print(" !skills List available skills | 列出可用技能") + print(" !enable <skill> Enable skill | 启用技能") + print(" !disable <skill> Disable skill | 禁用技能") + print(" Runtime Modes | 运行模式:") + print(" !status Show current session status and settings | 显示当前会话状态") + print(" !recursion_limit <N> Set recursion limit (default: 1000) | 设置递归限制") + print(" !plan on/off Enable/disable plan mode | 开启/关闭计划模式") + print(" !subagent on/off Enable/disable subagent delegation | 开启/关闭子代理") + print(" Memory System | 记忆系统:") + print(" !memory Show current memory | 查看当前记忆") + print(" !clear Clear current session memory | 清空当前会话记忆") + print(" Input | 输入:") + print(" !multi Enter multi-line input mode | 进入多行输入模式") + print(" System | 系统:") + print(" !help Show this help message | 显示帮助信息") + print(" !exit Exit the system | 退出系统") + print() + continue + + if user_input.lower().startswith("!new"): + parts = user_input.split(maxsplit=2) + sid = parts[1] if len(parts) > 1 else None + title = parts[2] if len(parts) > 2 else None + engine.create_session(sid, title) + continue + + if user_input.lower().startswith("!switch"): + parts = user_input.split() + if len(parts) < 2: + print("[Error] Usage: !switch <session_id> | 用法: !switch <会话ID>") + continue + engine.switch_session(parts[1]) + continue + + if user_input.lower().startswith("!delete session"): + parts = user_input.split() + if len(parts) < 3: + print("[Error] Usage: !delete session <session_id> | 用法: !delete session <会话ID>") + 
continue + engine.delete_session(parts[2]) + continue + + if user_input.lower().startswith("!rename"): + parts = user_input.split(maxsplit=1) + if len(parts) < 2: + print("[Error] Usage: !rename <new_title> | 用法: !rename <新标题>") + continue + engine.rename_session(engine.current_session_id, parts[1]) + continue + + if user_input.lower() == "!archives": + engine.list_archives() + continue + + if user_input.lower().startswith("!archive"): + parts = user_input.split() + if len(parts) < 2: + print("[Error] Usage: !archive <session_id> | 用法: !archive <会话ID>") + continue + engine.archive_session(parts[1]) + continue + + if user_input.lower().startswith("!restore"): + parts = user_input.split() + if len(parts) < 2: + print("[Error] Usage: !restore <session_id> | 用法: !restore <会话ID>") + continue + engine.restore_archive(parts[1]) + continue + + if user_input.lower() == "!sessions": + engine.list_sessions() + continue + + if user_input.lower() == "!export": + engine.export_session_markdown() + continue + + if user_input.lower() == "!export_all": + engine.export_all_checkpoints() + continue + + if user_input.lower().startswith("!search"): + parts = user_input.split(maxsplit=1) + if len(parts) < 2: + print("[Error] Usage: !search <keyword> | 用法: !search <关键词>") + continue + engine.search_sessions(parts[1]) + continue + + if user_input.lower() == "!steps": + steps = engine.get_session_steps() + print("\n[Step List | 步骤列表]") + for step in steps: + preview = step["user_input"][:60] + "..." if len(step["user_input"]) > 60 else step["user_input"] + print(f" {step['step']}. {preview}") + print() + continue + + if user_input.lower() == "!steps_all": + cps = engine.get_all_checkpoint_steps() + print(f"\n[All Checkpoints | 全部检查点] Total: {len(cps)}\n") + if not cps: + print(" No checkpoints found. 
| 未找到检查点。\n") + continue + for idx, cp in enumerate(cps, 1): + ts_display = str(cp["ts"]) if cp["ts"] is not None else "N/A" + new_flag = "✓ New content | 有新内容" if cp["has_new_content"] else "✗ No new content | 无新增" + print(f" [{idx}] {cp['checkpoint_id'][:8]}... | ts:{ts_display} | {new_flag}") + print() + continue + + if user_input.lower() == "!diagnose": + engine.diagnose_tool_calls() + continue + + # TODO: Rollback commands — checkpoint_id passthrough to DeerFlowClient is not yet implemented. + # See client._get_runnable_config (missing configurable["checkpoint_id"]). + # if user_input.lower().startswith("!back_cp"): + # parts = user_input.split() + # if len(parts) < 2: + # print("[Error] Usage: !back_cp <checkpoint_index> | 用法: !back_cp <检查点索引>") + # continue + # try: + # cp_idx = int(parts[1]) + # cps = engine.get_all_checkpoint_steps() + # if cp_idx < 1 or cp_idx > len(cps): + # print(f"[Error] Checkpoint index must be between 1 and {len(cps)} | 检查点索引必须在1到{len(cps)}之间") + # continue + # target_cp = cps[cp_idx - 1] + # print(f"\n[Rollback to Checkpoint | 回退到检查点] Index {cp_idx} | ID: {target_cp['checkpoint_id']}") + # print("AI: ", end="", flush=True) + # for chunk in engine.chat("Continue analysis from this checkpoint | 继续从该检查点分析", + # checkpoint_id=target_cp["checkpoint_id"]): + # print(chunk, end="", flush=True) + # print("\n") + # except ValueError: + # print("[Error] Invalid checkpoint index | 无效的检查点索引") + # continue + + # if user_input.lower().startswith("!back"): + # parts = user_input.split() + # if len(parts) < 2: + # print("[Error] Usage: !back <step_number> | 用法: !back <步骤号>") + # continue + # try: + # step_num = int(parts[1]) + # steps = engine.get_session_steps() + # if step_num < 1 or step_num > len(steps): + # print(f"[Error] Step must be between 1 and {len(steps)} | 步骤必须在1到{len(steps)}之间") + # continue + # target_step = steps[step_num - 1] + # print(f"\n[Rollback | 回溯] Reverted to step {step_num} | 已回退到步骤 {step_num}") + # print(f"Context | 上下文: 
{target_step['user_input']}\n") + # print("AI: ", end="", flush=True) + # for chunk in engine.chat("Continue analysis from here | 继续从这里开始分析", checkpoint_id=target_step["checkpoint_id"]): + # print(chunk, end="", flush=True) + # print("\n") + # except ValueError: + # print("[Error] Invalid step number | 无效的步骤号") + # continue + + if user_input.lower().startswith("!upload"): + parts = user_input.split() + if len(parts) < 2: + print("[Error] Usage: !upload <file_path> | 用法: !upload <文件路径>") + continue + engine.upload_file(parts[1]) + continue + + if user_input.lower() == "!files": + listing = engine.list_uploads() + if listing and listing.get("count", 0) > 0: + print("\n[Uploaded Files | 上传文件]") + for f in listing["files"]: + print(f" {f['filename']} | {f['size']} bytes") + print() + else: + print("\n[Uploaded Files | 上传文件] No files uploaded | 暂无文件\n") + continue + + if user_input.lower().startswith("!delete") and not user_input.lower().startswith("!delete session"): + parts = user_input.split() + if len(parts) < 2: + print("[Error] Usage: !delete <filename> | 用法: !delete <文件名>") + continue + engine.delete_upload(parts[1]) + continue + + if user_input.lower() == "!models": + models = engine.client.list_models()["models"] + print("\n[Available Models | 可用模型]") + for m in models: + status = "✓ Current | 当前使用" if m["name"] == engine.client._model_name else "" + thinking = "✓ Supports thinking | 支持思考" if m["supports_thinking"] else "" + print(f" {m['name']} | {m['display_name']} {thinking} {status}") + print() + continue + + if user_input.lower().startswith("!use"): + parts = user_input.split() + if len(parts) < 2: + print("[Error] Usage: !use <model_name> | 用法: !use <模型名>") + continue + engine.switch_model(parts[1]) + continue + + if user_input.lower() == "!skills": + skills = engine.client.list_skills()["skills"] + print("\n[Available Skills | 可用技能]") + for s in skills: + status = "✓ Enabled | 已启用" if s["enabled"] else "✗ Disabled | 已禁用" + print(f" {s['name']} | 
{s['category']} | {status}") + print() + continue + + if user_input.lower().startswith("!enable"): + parts = user_input.split() + if len(parts) < 2: + print("[Error] Usage: !enable <skill_name> | 用法: !enable <技能名>") + continue + engine.enable_skill(parts[1]) + continue + + if user_input.lower().startswith("!disable"): + parts = user_input.split() + if len(parts) < 2: + print("[Error] Usage: !disable <skill_name> | 用法: !disable <技能名>") + continue + engine.disable_skill(parts[1]) + continue + + if user_input.lower() == "!status": + engine.show_status() + continue + + if user_input.lower().startswith("!plan"): + parts = user_input.split() + if len(parts) < 2 or parts[1] not in ["on", "off"]: + print("[Error] Usage: !plan on|off | 用法: !plan on|off") + continue + if parts[1] == "on": + engine.enable_plan_mode() + else: + engine.disable_plan_mode() + continue + + if user_input.lower().startswith("!subagent"): + parts = user_input.split() + if len(parts) < 2 or parts[1] not in ["on", "off"]: + print("[Error] Usage: !subagent on|off | 用法: !subagent on|off") + continue + if parts[1] == "on": + engine.enable_subagent() + else: + engine.disable_subagent() + continue + + if user_input.lower().startswith("!recursion_limit"): + parts = user_input.split() + if len(parts) < 2: + print("[Error] Usage: !recursion_limit <value> | 用法: !recursion_limit <数值>") + continue + try: + limit = int(parts[1]) + engine.set_recursion_limit(limit) + except ValueError: + print("[Error] Recursion limit must be an integer | 递归限制必须是整数") + continue + + if user_input.lower() == "!memory": + memory = engine.client.get_memory() + facts = memory.get("facts", []) + print("\n[Current Memory | 当前记忆]") + if not facts: + print(" No memory facts available | 暂无记忆事实") + else: + for i, fact in enumerate(facts, 1): + print(f" {i}. 
[{fact['category']}] {fact['content']} (Confidence | 置信度: {fact['confidence']:.2f})") + print() + continue + + if user_input.lower() == "!clear": + engine.client.clear_memory() + print("\n[Memory | 记忆] Cleared current session memory | 已清空当前会话记忆\n") + continue + + # Normal chat interaction + print("AI: ", end="", flush=True) + for chunk in engine.chat(user_input): + print(chunk, end="", flush=True) + print("\n") + + except KeyboardInterrupt: + engine.shutdown() + break + except RecursionError as e: + print(f"\n\n[Critical Error | 严重错误] RecursionError: {str(e)}") + print("\n[Traceback | 堆栈跟踪]") + traceback.print_exc() + + # Show current status to help diagnose + print("\n[Session Status at Error | 错误发生时的会话状态]") + try: + client = engine.client + if client: + subagent_enabled = getattr(client, '_subagent_enabled', False) + plan_mode = getattr(client, '_plan_mode', False) + print(f" Subagent: {'✓ Enabled' if subagent_enabled else '✗ Disabled'}") + print(f" Plan Mode: {'✓ Enabled' if plan_mode else '✗ Disabled'}") + except Exception: + print(" (Unable to retrieve status)") + + print("\n[Debug Info | 调试信息]") + print(f" Current Session: {engine.current_session_id}") + print(f" Checkpoint DB: {engine._get_checkpoint_path(engine.current_session_id)}") + print("\n[Troubleshooting | 故障排除]") + print(" 1. 使用 !status 查看完整会话状态") + print(" 2. 使用 !diagnose 分析工具调用模式") + print(" 3. 使用 !steps_all 查看已保存的检查点") + print(" 4. 使用 !export_all 导出完整检查点历史") + if engine.client and getattr(engine.client, '_subagent_enabled', False): + print(" 5. 
Subagent 当前已启用 - 尝试关闭: !subagent off") + print() + print() + except Exception as e: + print(f"\n\n[Error | 错误] {type(e).__name__}: {str(e)}") + print("\n[Traceback | 堆栈跟踪]") + traceback.print_exc() + print() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/cli/docker-compose.yaml b/cli/docker-compose.yaml new file mode 100644 index 0000000000..136c9a5f1f --- /dev/null +++ b/cli/docker-compose.yaml @@ -0,0 +1,12 @@ +services: + app: + container_name: deer-flow-cli + image: deer-flow-cli + build: + context: .. + dockerfile: cli/Dockerfile + volumes: + - ..:/deer-flow + env_file: + - ../.env + command: ["tail", "-f", "/dev/null"] diff --git a/cli/engine.py b/cli/engine.py new file mode 100644 index 0000000000..557fe503f2 --- /dev/null +++ b/cli/engine.py @@ -0,0 +1,1329 @@ +""" +DeerFlow Production Engine + +A production-grade, session-aware runtime engine for DeerFlow AI agents. +Features complete session management, persistence, streaming, and tool integration. + +Isolation design +================ +Session-level isolation is enforced through two complementary mechanisms: + +1. **Per-session DeerFlowClient instances.** Each session owns an independent + DeerFlowClient and SQLite checkpointer. Agent state, runtime settings, and + conversation history never leak across sessions. + +2. **Global shared-resource reset on session switch.** The deerflow backend + package holds module-level singletons (MCP session pool, MCP tool cache, + subagent background tasks, memory queue, etc.) that outlive any single + client instance. ``_reset_shared_resources()`` clears every known global + on each user-initiated session switch. + + This is a *best-effort* reset — it depends on the backend exposing public + reset functions for every piece of mutable global state. If a future + backend version adds new global state without a corresponding reset API, + it will not be caught here. 
For single-user CLI use this trade-off is + acceptable; a multi-tenant server should spawn a subprocess per session + for guaranteed isolation (see issue #3292). + +All checkpoints are preserved for full model behavior auditing. + +Author: heart-scalpel +License: MIT +""" + +import os +import time +import re +import uuid +import json +from pathlib import Path + +from langgraph.checkpoint.sqlite import SqliteSaver +from deerflow.client import DeerFlowClient +from session_store import SessionStore + +# Configuration constants +WORK_DIR = Path("./.deer-flow") +SESSIONS_DIR = WORK_DIR / "deerflow_sessions" +ARCHIVE_DIR = SESSIONS_DIR / "archive" + + +class DeerFlowProductionEngine: + """ + Production-grade singleton engine for DeerFlow agent execution. + + Manages session lifecycle, persistence, streaming responses, and agent + configuration. Each session owns an independent DeerFlowClient and + SQLite checkpointer so that agent state, runtime settings, and + conversation history are fully isolated. + + On every user-initiated session switch, ``_reset_shared_resources()`` + clears all known module-level globals in the deerflow backend: + MCP session pool + tool cache, subagent background tasks + usage cache, + memory storage cache + update queue. See that method's docstring for + the full inventory and the residual risks of this best-effort approach. 
+ """ + + _instance = None + _initialized = False + + def __new__(cls): + """Create or return the singleton instance.""" + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self): + """Initialize the engine if not already initialized.""" + if self._initialized: + return + self._initialized = True + + # ------------------------------------------------------------------ + # Session persistence + # ------------------------------------------------------------------ + self.store = SessionStore(SESSIONS_DIR, ARCHIVE_DIR) + + # ------------------------------------------------------------------ + # Per-session client instances (complete agent state isolation) + # ------------------------------------------------------------------ + self._clients: dict[str, DeerFlowClient] = {} + self._checkpointer_cms: dict[str, object] = {} + self._checkpointers: dict[str, object] = {} + + self.current_session_id = None + + # Runtime settings template — applied to every new client so that + # user preferences (model, plan mode, subagent, thinking) survive + # across session switches. 
+ self._runtime_settings = { + "model_name": None, + "plan_mode": False, + "subagent_enabled": False, + "thinking_enabled": True, + "recursion_limit": 1000, + } + + # ------------------------------------------------------------------ + # Bootstrap: load existing sessions or create a default + # ------------------------------------------------------------------ + if not self.store.sessions: + self._create_default_session() + else: + first_session_id = next(iter(self.store.sessions.keys())) + self._activate_session(first_session_id) + + # ------------------------------------------------------------------ + # Checkpointer paths + # ------------------------------------------------------------------ + + def _get_checkpoint_path(self, session_id: str) -> Path: + """Get the database path for a specific session.""" + return SESSIONS_DIR / f"{session_id}_checkpoints.db" + + def _get_archive_checkpoint_path(self, session_id: str) -> Path: + """Get the archived database path for a specific session.""" + return ARCHIVE_DIR / f"{session_id}_checkpoints.db" + + # ------------------------------------------------------------------ + # Client property + # ------------------------------------------------------------------ + + @property + def client(self) -> DeerFlowClient | None: + """Return the DeerFlowClient for the current session.""" + if self.current_session_id is None: + return None + return self._clients.get(self.current_session_id) + + # ------------------------------------------------------------------ + # Per-session client and checkpointer management + # ------------------------------------------------------------------ + + def _get_or_create_client(self, session_id: str) -> DeerFlowClient: + """Return (or create) the DeerFlowClient for *session_id*. + + Ensures the session's SQLite checkpointer is open. Does **not** + reset shared global resources — use ``_activate_session`` for + user-initiated switches that require a full reset. 
+        """
+        # Ensure a live checkpointer for this session.
+        if session_id not in self._checkpointer_cms:
+            db_path = self._get_checkpoint_path(session_id)
+            cm = SqliteSaver.from_conn_string(str(db_path))
+            # Enter the context manager *before* registering it: if opening
+            # the database fails, neither dict is touched, so a later retry
+            # starts clean instead of hitting a KeyError on _checkpointers.
+            saver = cm.__enter__()
+            self._checkpointer_cms[session_id] = cm
+            self._checkpointers[session_id] = saver
+
+        checkpointer = self._checkpointers[session_id]
+
+        # Return existing client after refreshing its checkpointer ref
+        # (the old SqliteSaver may have been closed across a switch).
+        if session_id in self._clients:
+            client = self._clients[session_id]
+            client._checkpointer = checkpointer
+            return client
+
+        # First time this session is used — create a fresh client and copy
+        # the user's current runtime preferences onto it.
+        client = DeerFlowClient(checkpointer=checkpointer)
+        settings = self._runtime_settings
+        if settings["model_name"]:
+            client._model_name = settings["model_name"]
+        client._plan_mode = settings["plan_mode"]
+        client._subagent_enabled = settings["subagent_enabled"]
+        client._thinking_enabled = settings["thinking_enabled"]
+
+        self._clients[session_id] = client
+        return client
+
+    def _activate_session(self, session_id: str):
+        """Activate *session_id* for user interaction.
+
+        Ensures the target session has a live client and checkpointer,
+        then closes the previous session's checkpointer (keeping its
+        client alive for later reuse) and resets shared global resources.
+
+        The target is opened first on purpose: if that raises, the previous
+        session is left fully usable instead of pointing at a closed
+        SQLite connection.
+        """
+        previous_id = self.current_session_id
+
+        # Ensure target session is ready (may raise; nothing changed yet).
+        self._get_or_create_client(session_id)
+
+        # Close the *previous* session's checkpointer to release the
+        # SQLite connection, but leave its client in the dict so runtime
+        # settings are preserved.
+        if previous_id and previous_id != session_id:
+            self._close_checkpointer(previous_id)
+
+        # Prevent cross-session contamination from module-level globals.
+        self._reset_shared_resources()
+
+        self.current_session_id = session_id
+
+    def _close_checkpointer(self, session_id: str):
+        """Close the SQLite checkpointer for *session_id*.
+
+        A no-op when the session has no open checkpointer.  The cached
+        saver reference is dropped even if closing raises, so no stale
+        handle is left behind.
+        """
+        cm = self._checkpointer_cms.pop(session_id, None)
+        try:
+            if cm is not None:
+                cm.__exit__(None, None, None)
+        finally:
+            self._checkpointers.pop(session_id, None)
+
+    def _destroy_client(self, session_id: str):
+        """Fully tear down a session's client and checkpointer."""
+        self._clients.pop(session_id, None)
+        self._close_checkpointer(session_id)
+
+    def _reset_shared_resources(self):
+        """Reset all known module-level globals in the deerflow backend.
+
+        This method is the single point of accountability for cross-session
+        cleanup. Every piece of mutable global state discovered in the
+        backend audit that has a public reset API is cleared here.
+
+        Resources intentionally NOT reset:
+        - ``_isolated_subagent_loop`` — persistent event loop, expensive to
+          recreate, carries no session-specific state, atexit-managed.
+        - ``_scheduler_pool`` — ThreadPoolExecutor, expensive, stateless.
+        - ``_SYNC_TOOL_EXECUTOR``, ``_SYNC_MEMORY_UPDATER_EXECUTOR`` —
+          stateless thread pools, atexit-managed.
+        - Middleware dicts (todo, loop_detection) — keyed by
+          ``(thread_id, run_id)``, naturally scoped; no public reset API.
+
+        Fragility note: if a future backend version adds new global state
+        without a corresponding reset function, it will NOT be caught here.
+        This is the fundamental limitation of option 2 vs subprocess
+        isolation. For single-user CLI use this trade-off is acceptable.
+        """
+        # ── MCP layer ───────────────────────────────────────────────
+        # Closes all persistent MCP sessions, clears _mcp_tools_cache,
+        # resets _pool singleton, resets _cache_initialized / _config_mtime.
+ try: + from deerflow.mcp.cache import ( + reset_mcp_tools_cache, + ) + + reset_mcp_tools_cache() + except Exception: + pass + + # ── Subagent layer ────────────────────────────────────────── + # _background_tasks: dict of SubagentResult keyed by task_id. + # Without clearing, list_background_tasks() exposes stale results + # from the previous session. + try: + from deerflow.subagents.executor import ( + _background_tasks, + _background_tasks_lock, + ) + + with _background_tasks_lock: + _background_tasks.clear() + except Exception: + pass + + # _subagent_usage_cache: dict of token usage keyed by tool_call_id. + # Stale entries from completed/abandoned sessions persist until + # explicitly cleared. + try: + from deerflow.tools.builtins.task_tool import ( + _subagent_usage_cache, + ) + + _subagent_usage_cache.clear() + except Exception: + pass + + # ── Memory layer ──────────────────────────────────────────── + # _storage_instance (FileMemoryStorage) holds an in-memory cache + # of facts keyed by (user_id, agent_name). reload() re-reads + # from disk so the next session picks up any file-system changes. + try: + from deerflow.agents.memory.storage import get_memory_storage + + storage = get_memory_storage() + if hasattr(storage, "reload"): + storage.reload() + except Exception: + pass + + # _memory_queue batches ConversationContext objects across + # sessions. reset_memory_queue() drains the queue and replaces + # the singleton so queued contexts from the old session are not + # flushed into the new one. 
+ try: + from deerflow.agents.memory.queue import reset_memory_queue + + reset_memory_queue() + except Exception: + pass + + # ------------------------------------------------------------------ + # Session lifecycle + # ------------------------------------------------------------------ + + def _create_default_session(self): + """Create a default session when no sessions exist.""" + return self.create_session(title="New Session") + + def _ensure_current_session(self): + """Ensure a valid current session exists.""" + if self.current_session_id is None or self.current_session_id not in self.store.sessions: + if self.store.sessions: + first_session_id = next(iter(self.store.sessions.keys())) + self._activate_session(first_session_id) + else: + self._create_default_session() + + def shutdown(self): + """Gracefully shut down the engine and release all resources.""" + print("\n[Engine] Shutting down gracefully...") + self.store.shutdown() + for sid in list(self._clients.keys()): + self._destroy_client(sid) + for sid in list(self._checkpointer_cms.keys()): + self._close_checkpointer(sid) + print("[Engine] Shutdown complete") + + def create_session(self, session_id=None, title=None): + """ + Create a new conversation session with its own isolated database + and DeerFlowClient instance. + + Args: + session_id: Optional custom session ID. Auto-generated if None. + title: Optional session title. Defaults to "New Session". + + Returns: + str: The created session ID. 
+ """ + if session_id is None or not re.fullmatch(r'[\w-]+', session_id): + session_id = uuid.uuid4().hex + if session_id in self.store.sessions: + print(f"[Session] ID already exists: {session_id}") + return session_id + self.store.sessions[session_id] = { + "created_at": time.time(), + "last_active": time.time(), + "title": title or "New Session", + "last_checkpoint_id": None, + } + self.store.session_metrics[session_id] = { + "total_tokens": 0, + "tool_calls": 0, + "turns": 0, + } + self.store.save_async(session_id) + + self._activate_session(session_id) + + print(f"[Session] Created: {session_id}") + return session_id + + def switch_session(self, session_id): + """ + Switch to an existing session with complete state isolation. + + Closes the previous session's checkpointer, activates the target + session's client, and resets shared global resources (MCP pool, + subagent background tasks). + + Args: + session_id: ID of the session to switch to. + + Returns: + bool: True if switch was successful, False otherwise. + """ + if session_id not in self.store.sessions: + print(f"[Error] Session {session_id} not found") + return False + + self._activate_session(session_id) + self.store.sessions[session_id]["last_active"] = time.time() + self.store.save_async(session_id) + + print(f"[Session] Switched to: {session_id}") + return True + + def delete_session(self, session_id): + """ + Delete a session and all associated files including its database. + + Args: + session_id: ID of the session to delete. + + Returns: + bool: True if deletion was successful, False otherwise. 
+ """ + if session_id not in self.store.sessions: + print(f"[Error] Session {session_id} not found") + return False + + self._destroy_client(session_id) + + self.store.delete_session_files(session_id) + db_path = self._get_checkpoint_path(session_id) + if db_path.exists(): + db_path.unlink() + + if self.current_session_id == session_id: + self.current_session_id = None + self._ensure_current_session() + + print(f"[Session] Deleted: {session_id}") + return True + + def rename_session(self, session_id, new_title): + """ + Rename an existing session. + + Args: + session_id: ID of the session to rename. + new_title: New title for the session. + + Returns: + bool: True if rename was successful, False otherwise. + """ + if session_id not in self.store.sessions: + print(f"[Error] Session {session_id} not found") + return False + self.store.sessions[session_id]["title"] = new_title + self.store.save_async(session_id) + print(f"[Session] Renamed to: {new_title}") + return True + + def archive_session(self, session_id): + """ + Archive a session, moving all files including its database to the + archive directory. + + Args: + session_id: ID of the session to archive. + + Returns: + bool: True if archiving was successful, False otherwise. 
+ """ + if session_id not in self.store.sessions: + print(f"[Error] Session {session_id} not found") + return False + + self._destroy_client(session_id) + + self.store.archive_session_files(session_id) + db_path = self._get_checkpoint_path(session_id) + archive_db_path = self._get_archive_checkpoint_path(session_id) + if db_path.exists(): + db_path.rename(archive_db_path) + + if self.current_session_id == session_id: + self.current_session_id = None + self._ensure_current_session() + + print(f"[Session] Archived: {session_id}") + return True + + def list_archives(self): + """Print a list of all archived sessions.""" + print("\n[Archived Sessions]") + archives = list(ARCHIVE_DIR.glob("*.json")) + if not archives: + print(" No archived sessions") + else: + for f in archives: + print(f" {f.stem}") + print() + + def list_sessions(self): + """Print a list of all active sessions with their metrics.""" + print("\n[Session List]") + for sid, info in self.store.sessions.items(): + metrics = self.store.session_metrics[sid] + current = "← Current" if sid == self.current_session_id else "" + title = info.get("title", "New Session") + print( + f" {sid} | {title} | Turns: {metrics['turns']} | " + f"Tokens: {metrics['total_tokens']} {current}" + ) + print() + + # ------------------------------------------------------------------ + # Read-only session introspection (no global reset) + # ------------------------------------------------------------------ + + def _extract_steps(self, session_id: str): + """Extract structured steps from a session's checkpoint history. + + Returns a list of step dicts without switching the active session. 
+ """ + client = self._get_or_create_client(session_id) + thread_data = client.get_thread(session_id) + checkpoints = thread_data.get("checkpoints", []) + if not checkpoints: + return [] + + seen_message_ids: set[str] = set() + steps: list[dict] = [] + current_step = None + + for cp_idx, cp in enumerate(checkpoints): + messages = cp["values"].get("messages", []) + + for msg in messages: + msg_id = msg.get("id") + if msg_id is None: + msg_id = f"__no_id__:{msg.get('type', '')}:{msg.get('content', '')}" + is_duplicate = msg_id in seen_message_ids + + if not is_duplicate: + seen_message_ids.add(msg_id) + + if msg["type"] == "human": + if current_step: + steps.append(current_step) + current_step = { + "step": len(steps) + 1, + "checkpoint_id": cp.get("checkpoint_id"), + "parent_checkpoint_id": cp.get("parent_checkpoint_id"), + "ts": cp.get("ts"), + "total_tokens": cp["values"].get("total_tokens"), + "user_input": msg["content"], + "user_files": msg.get("metadata", {}).get("files", []), + "ai_response": "", + "tool_calls": [], + "ai_response_metadata": {}, + "duplicate_messages": [], + } + elif msg["type"] == "ai" and current_step: + current_step["ai_response"] += msg.get("content", "") + current_step["ai_response_metadata"] = msg.get( + "response_metadata", {} + ) + if msg.get("tool_calls"): + for tc in msg["tool_calls"]: + current_step["tool_calls"].append({ + "id": tc["id"], + "name": tc["name"], + "args": tc["args"], + "result": "", + "is_duplicate": False, + }) + elif msg["type"] == "tool" and current_step: + for tc in current_step["tool_calls"]: + if tc["id"] == msg["tool_call_id"]: + tc["result"] = msg.get("content", "") + break + else: + if current_step: + current_step["duplicate_messages"].append({ + "type": msg["type"], + "checkpoint_id": cp.get("checkpoint_id"), + "checkpoint_index": cp_idx, + }) + + if current_step: + steps.append(current_step) + + # Mark duplicate tool calls + seen_tool_call_ids: set[str] = set() + for step in steps: + for tc in 
step["tool_calls"]: + if tc["id"] in seen_tool_call_ids: + tc["is_duplicate"] = True + else: + seen_tool_call_ids.add(tc["id"]) + + return steps + + def get_session_steps(self, session_id=None): + """ + Get structured conversation steps with duplicate detection. + + Uses the per-session client directly — no global session switch, + so shared resources are left untouched. + + Args: + session_id: ID of the session. Uses current session if None. + + Returns: + list: List of step dictionaries containing conversation data. + """ + session_id = session_id or self.current_session_id + if not session_id: + return [] + return self._extract_steps(session_id) + + def get_all_checkpoint_steps(self, session_id=None): + """ + Get a list of all checkpoints as individual steps. + + Only messages newly appearing in each checkpoint are shown + (duplicates hidden). Every checkpoint is preserved for precise + rollback and auditing. + + Uses the per-session client directly — no global session switch. + + Args: + session_id: Session ID, uses current session if None. + + Returns: + list[dict]: Each element represents a checkpoint. 
+ """ + session_id = session_id or self.current_session_id + if not session_id: + return [] + + client = self._get_or_create_client(session_id) + thread_data = client.get_thread(session_id) + checkpoints = thread_data.get("checkpoints", []) + if not checkpoints: + return [] + + seen_message_ids: set[str] = set() + checkpoint_steps: list[dict] = [] + + for cp in checkpoints: + messages = cp["values"].get("messages", []) + new_msgs = [] + for msg in messages: + msg_id = msg.get("id") + if msg_id is None: + msg_id = ( + f"__no_id__:{msg.get('type', '')}:{msg.get('content', '')}" + ) + if msg_id not in seen_message_ids: + seen_message_ids.add(msg_id) + new_msgs.append(msg) + + checkpoint_steps.append({ + "checkpoint_id": cp.get("checkpoint_id"), + "parent_checkpoint_id": cp.get("parent_checkpoint_id"), + "ts": cp.get("ts"), + "new_messages": new_msgs, + "has_new_content": len(new_msgs) > 0, + }) + + return checkpoint_steps + + # ------------------------------------------------------------------ + # Export + # ------------------------------------------------------------------ + + def export_session_markdown(self, session_id=None): + """ + Export a session to a formatted Markdown file. + + Args: + session_id: ID of the session. Uses current session if None. + + Returns: + str: Path to the exported Markdown file, or None on failure. 
+ """ + session_id = session_id or self.current_session_id + if not session_id or session_id not in self.store.sessions: + print("[Error] No active session") + return None + + steps = self.get_session_steps(session_id) + info = self.store.sessions[session_id] + title = info.get("title", "Session Export") + + md = f"# {title}\n\n" + md += f"Session ID: {session_id}\n" + md += f"Created: {time.ctime(info['created_at'])}\n" + md += f"Last Active: {time.ctime(info['last_active'])}\n" + md += f"Total Turns: {len(steps)}\n" + md += f"Total Tokens: {self.store.session_metrics[session_id]['total_tokens']}\n\n" + md += "---\n\n" + + for step in steps: + md += f"## Turn {step['step']}\n\n" + md += f"**User**: {step['user_input']}\n\n" + md += f"**AI**: {step['ai_response']}\n\n" + + if step["tool_calls"]: + md += "**Tool Calls**\n\n" + for tc in step["tool_calls"]: + if tc["is_duplicate"]: + md += f"### {tc['name']} ⚠️ Duplicate\n" + else: + md += f"### {tc['name']}\n" + + md += "**Parameters**:\n" + md += f"```json\n{json.dumps(tc['args'], ensure_ascii=False, indent=2)}\n```\n" + + if tc["result"]: + md += "**Result**:\n" + try: + if isinstance(tc["result"], str): + result_json = json.loads(tc["result"]) + md += f"```json\n{json.dumps(result_json, ensure_ascii=False, indent=2)}\n```\n" + else: + md += f"```json\n{json.dumps(tc['result'], ensure_ascii=False, indent=2)}\n```\n" + except (json.JSONDecodeError, TypeError): + md += f"```\n{tc['result']}\n```\n" + md += "\n" + + if step.get("duplicate_messages"): + duplicate_count = len(step["duplicate_messages"]) + md += ( + f"⚠️ **Note**: {duplicate_count} duplicate messages detected " + "across checkpoints (not shown)\n\n" + ) + + md += "---\n\n" + + session_dir = SESSIONS_DIR / session_id + session_dir.mkdir(parents=True, exist_ok=True) + timestamp = int(time.time()) + filename = session_dir / f"export_{timestamp}.md" + + with open(filename, "w", encoding="utf-8") as f: + f.write(md) + + print(f"[Export] Session exported to: 
{filename}") + return str(filename) + + def export_all_checkpoints(self, session_id=None): + """ + Export all checkpoints to a Markdown file. + + Duplicate messages are hidden but every checkpoint round is listed. + Checkpoints with no new messages are still included for full + traceability. + + Args: + session_id: Session ID, uses current session if None. + + Returns: + str: Path to the exported Markdown file, or None on failure. + """ + session_id = session_id or self.current_session_id + if not session_id or session_id not in self.store.sessions: + print("[Error] No active session") + return None + + all_steps = self.get_all_checkpoint_steps(session_id) + info = self.store.sessions[session_id] + title = info.get("title", "Session Export All") + + md = f"# {title} (All Checkpoints)\n\n" + md += f"Session ID: {session_id}\n" + md += f"Created: {time.ctime(info['created_at'])}\n" + md += f"Last Active: {time.ctime(info['last_active'])}\n" + md += f"Total Checkpoints: {len(all_steps)}\n" + md += f"Total Tokens: {self.store.session_metrics[session_id]['total_tokens']}\n\n" + md += "---\n\n" + + for idx, step in enumerate(all_steps, 1): + ts_display = str(step["ts"]) if step["ts"] is not None else "Unknown" + md += f"## Checkpoint {idx}\n\n" + md += f"- **ID**: `{step['checkpoint_id']}`\n" + md += f"- **Parent ID**: `{step['parent_checkpoint_id']}`\n" + md += f"- **Time**: {ts_display}\n\n" + + if not step["has_new_content"]: + md += ( + "⚠️ This checkpoint introduced no new messages " + "(content identical to previous checkpoint).\n\n" + ) + else: + for msg in step["new_messages"]: + if msg["type"] == "human": + md += f"### [User]\n\n{msg['content']}\n\n" + elif msg["type"] == "ai": + content = msg.get("content", "") + if content: + md += f"### [AI]\n\n{content}\n\n" + if msg.get("tool_calls"): + for tc in msg["tool_calls"]: + md += f"#### [Tool Call: {tc['name']}]\n\n" + md += f"```json\n{json.dumps(tc['args'], ensure_ascii=False, indent=2)}\n```\n\n" + elif 
msg["type"] == "tool": + result = msg.get("content", "") + md += "#### [Tool Result]\n\n" + try: + result_json = json.loads(result) + md += f"```json\n{json.dumps(result_json, ensure_ascii=False, indent=2)}\n```\n\n" + except (json.JSONDecodeError, TypeError): + md += f"```\n{result}\n```\n\n" + md += "---\n\n" + + session_dir = SESSIONS_DIR / session_id + session_dir.mkdir(parents=True, exist_ok=True) + timestamp = int(time.time()) + filename = session_dir / f"export_all_checkpoints_{timestamp}.md" + + with open(filename, "w", encoding="utf-8") as f: + f.write(md) + + print(f"[Export] All checkpoints exported to: {filename}") + return str(filename) + + # ------------------------------------------------------------------ + # Search + # ------------------------------------------------------------------ + + def search_sessions(self, keyword): + """ + Search all active sessions for a keyword in user inputs or AI + responses. Uses each session's own client directly — no global + session switch, so shared resources are untouched. + + Args: + keyword: The keyword to search for (case-insensitive). + """ + print(f"\n[Search Results for: {keyword}]") + found = False + + for sid in self.store.sessions: + steps = self.get_session_steps(sid) + for step in steps: + if ( + keyword.lower() in step["user_input"].lower() + or keyword.lower() in step["ai_response"].lower() + ): + title = self.store.sessions[sid].get("title", "New Session") + print(f" Session: {sid} | {title} | Turn {step['step']}") + print(f" User: {step['user_input'][:80]}...") + found = True + break + + if not found: + print(" No matching sessions found") + print() + + # ------------------------------------------------------------------ + # Archive restore + # ------------------------------------------------------------------ + + def restore_archive(self, session_id, switch=True): + """ + Restore an archived session including its database file. + + Args: + session_id: ID of the archived session to restore. 
+ switch: Whether to switch to the restored session (default True). + + Returns: + bool: True if restoration was successful, False otherwise. + """ + archive_path = ARCHIVE_DIR / f"{session_id}.json" + archive_db_path = self._get_archive_checkpoint_path(session_id) + + if not archive_path.exists(): + print(f"[Error] Archive {session_id} not found") + return False + if session_id in self.store.sessions: + print(f"[Error] Session {session_id} already active") + return False + + with open(archive_path, "r", encoding="utf-8") as f: + data = json.load(f) + + self.store.sessions[session_id] = data["info"] + self.store.session_metrics[session_id] = data["metrics"] + + active_path = SESSIONS_DIR / f"{session_id}.json" + archive_path.rename(active_path) + + if archive_db_path.exists(): + active_db_path = self._get_checkpoint_path(session_id) + archive_db_path.rename(active_db_path) + + if switch: + self._activate_session(session_id) + + self.store.save_async(session_id) + print(f"[Session] Restored from archive: {session_id}") + return True + + # ------------------------------------------------------------------ + # Chat + # ------------------------------------------------------------------ + + def chat(self, message, session_id=None, **kwargs): + """ + Send a message to the agent and stream the response. + + Args: + message: The user's input message. + session_id: ID of the session. Uses current session if None. + **kwargs: Additional keyword arguments passed to client.stream(). + + Yields: + str: Chunks of the AI response, followed by metrics. 
+ """ + session_id = session_id or self.current_session_id + if not session_id or session_id not in self.store.sessions: + session_id = self.create_session() + + self.store.sessions[session_id]["last_active"] = time.time() + # Use runtime setting for recursion_limit if not specified + if "recursion_limit" not in kwargs: + kwargs["recursion_limit"] = self._runtime_settings.get("recursion_limit", 1000) + stream_kwargs = {"thread_id": session_id, **kwargs} + + full_response = "" + tool_calls = 0 + total_tokens = 0 + + tool_call_history: list[dict] = [] + + client = self._get_or_create_client(session_id) + try: + for event in client.stream(message, **stream_kwargs): + if event.type == "messages-tuple": + d = event.data + if d.get("type") == "ai" and d.get("content"): + content = d["content"] + full_response += content + yield content + if d.get("tool_calls"): + tool_calls += len(d["tool_calls"]) + for tc in d["tool_calls"]: + tool_call_history.append({ + "name": tc.get("name"), + "args": tc.get("args"), + "id": tc.get("id"), + }) + elif event.type == "end": + usage = event.data.get("usage", {}) + total_tokens = usage.get("total_tokens", 0) + except Exception as stream_error: + print(f"\n\n[Stream Error | 流错误] {type(stream_error).__name__}: {stream_error}") + print(f"\n[Tool Call Summary | 工具调用摘要]") + print(f" Total tool calls: {len(tool_call_history)}") + from collections import Counter + tool_counts = Counter(tc["name"] for tc in tool_call_history) + for tool_name, count in tool_counts.most_common(): + print(f" {tool_name}: {count} calls") + print(f"\n[Tool Call History | 工具调用历史] (最近5次)") + for tc in tool_call_history[-5:]: + print(f" - {tc['name']}: {tc.get('args', {})}") + raise + + self.store.session_metrics[session_id]["turns"] += 1 + self.store.session_metrics[session_id]["tool_calls"] += tool_calls + self.store.session_metrics[session_id]["total_tokens"] += total_tokens + + thread_data = client.get_thread(session_id) + if thread_data["checkpoints"]: + 
last_checkpoint_id = thread_data["checkpoints"][-1]["checkpoint_id"] + self.store.sessions[session_id]["last_checkpoint_id"] = last_checkpoint_id + + if ( + self.store.sessions[session_id].get("title") in (None, "New Session") + and full_response + ): + self.store.sessions[session_id]["title"] = ( + message[:30] + ("..." if len(message) > 30 else "") + ) + + self.store.save_async(session_id) + + yield f"\n\n[Metrics] Tokens: {total_tokens} | Tool Calls: {tool_calls}" + + # Check checkpoint count and warn if approaching recursion limit + if thread_data and thread_data.get("checkpoints"): + checkpoint_count = len(thread_data["checkpoints"]) + recursion_limit = self._runtime_settings.get("recursion_limit", 1000) + warning_threshold = int(recursion_limit * 0.8) # Warn at 80% of limit + critical_threshold = int(recursion_limit * 0.9) # Critical warning at 90% + + if checkpoint_count >= critical_threshold: + yield f"\n\n⚠️ [CRITICAL] Checkpoints: {checkpoint_count}/{recursion_limit} - Approaching recursion limit!" + client = self._get_or_create_client(session_id) + if hasattr(client, '_subagent_enabled') and client._subagent_enabled: + yield f" Subagent is enabled - consider disabling with: !subagent off" + elif checkpoint_count >= warning_threshold: + yield f"\n\n[WARNING] Checkpoints: {checkpoint_count}/{recursion_limit} - Getting close to limit" + + # ------------------------------------------------------------------ + # File upload + # ------------------------------------------------------------------ + + def upload_file(self, file_path, session_id=None): + """ + Upload a file to the current session. + + Args: + file_path: Path to the file to upload. + session_id: ID of the session. Uses current session if None. + + Returns: + dict: Upload result from the client, or None on failure. 
+ """ + session_id = session_id or self.current_session_id + if not session_id: + print("[Error] No active session") + return None + if not os.path.exists(file_path): + print(f"[Error] File not found: {file_path}") + return None + client = self._get_or_create_client(session_id) + result = client.upload_files(session_id, [file_path]) + print(f"[Upload] Success: {result['message']}") + return result + + def list_uploads(self, session_id=None): + """ + List all files uploaded to a session. + + Args: + session_id: ID of the session. Uses current session if None. + + Returns: + dict: List of uploaded files, or None if no active session. + """ + session_id = session_id or self.current_session_id + if not session_id: + print("[Error] No active session") + return None + client = self._get_or_create_client(session_id) + return client.list_uploads(session_id) + + def delete_upload(self, filename, session_id=None): + """ + Delete an uploaded file from a session. + + Args: + filename: Name of the file to delete. + session_id: ID of the session. Uses current session if None. + + Returns: + dict: Deletion result from the client, or None on failure. + """ + session_id = session_id or self.current_session_id + if not session_id: + print("[Error] No active session") + return None + client = self._get_or_create_client(session_id) + return client.delete_upload(session_id, filename) + + # ------------------------------------------------------------------ + # Runtime controls (apply to current client + persist for new sessions) + # ------------------------------------------------------------------ + + def enable_skill(self, skill_name): + """ + Enable a skill for the agent. + + Args: + skill_name: Name of the skill to enable. + + Returns: + bool: True if skill was enabled successfully. 
+ """ + client = self.client + if client is None: + return False + try: + client.update_skill(skill_name, enabled=True) + print(f"[Skill] Enabled: {skill_name}") + return True + except Exception as e: + print(f"[Error] Failed to enable skill: {e}") + return False + + def disable_skill(self, skill_name): + """ + Disable a skill for the agent. + + Args: + skill_name: Name of the skill to disable. + + Returns: + bool: True if skill was disabled successfully. + """ + client = self.client + if client is None: + return False + try: + client.update_skill(skill_name, enabled=False) + print(f"[Skill] Disabled: {skill_name}") + return True + except Exception as e: + print(f"[Error] Failed to disable skill: {e}") + return False + + def switch_model(self, model_name): + """ + Switch the agent to use a different model. + + The choice is persisted to ``_runtime_settings`` so that newly + created sessions inherit it. + + Args: + model_name: Name of the model to use. + + Returns: + bool: True if model was switched successfully. 
+ """ + client = self.client + if client is None: + return False + models = client.list_models()["models"] + if not any(m["name"] == model_name for m in models): + print(f"[Error] Model {model_name} not found") + return False + client._model_name = model_name + self._runtime_settings["model_name"] = model_name + print(f"[Model] Switched to: {model_name}") + return True + + # ------------------------------------------------------------------ + # Runtime controls (apply to current client + persist for new sessions) + # ------------------------------------------------------------------ + + def show_status(self): + """Display current session status and runtime settings.""" + client = self.client + if client is None: + print("[Error] No active session") + return + + print(f"\n[Session Status | 会话状态]") + print(f" Session ID: {self.current_session_id}") + + # Get client runtime settings + model_name = getattr(client, '_model_name', None) or self._runtime_settings.get("model_name") + plan_mode = getattr(client, '_plan_mode', None) if hasattr(client, '_plan_mode') else self._runtime_settings.get("plan_mode") + subagent_enabled = getattr(client, '_subagent_enabled', None) if hasattr(client, '_subagent_enabled') else self._runtime_settings.get("subagent_enabled") + thinking_enabled = getattr(client, '_thinking_enabled', None) if hasattr(client, '_thinking_enabled') else self._runtime_settings.get("thinking_enabled") + + print(f"\n[Runtime Settings | 运行时配置]") + print(f" Model: {model_name or 'default'}") + print(f" Subagent: {'✓ Enabled' if subagent_enabled else '✗ Disabled'}") + print(f" Plan Mode: {'✓ Enabled' if plan_mode else '✗ Disabled'}") + print(f" Thinking: {'✓ Enabled' if thinking_enabled else '✗ Disabled'}") + + # Get checkpoint info + thread_data = client.get_thread(self.current_session_id) + checkpoints = thread_data.get("checkpoints", []) + print(f"\n[Session Metrics | 会话指标]") + print(f" Checkpoints: {len(checkpoints)}") + recursion_limit = 
self._runtime_settings.get("recursion_limit", 1000) + print(f" Recursion Limit: {recursion_limit}") + if len(checkpoints) > int(recursion_limit * 0.8): + print(f" ⚠️ Approaching recursion limit ({len(checkpoints)}/{recursion_limit})") + print() + + def enable_plan_mode(self): + """Enable plan mode for the agent and persist the setting.""" + client = self.client + if client is None: + return + client._plan_mode = True + self._runtime_settings["plan_mode"] = True + print("[Mode] Plan mode enabled") + + def disable_plan_mode(self): + """Disable plan mode for the agent and persist the setting.""" + client = self.client + if client is None: + return + client._plan_mode = False + self._runtime_settings["plan_mode"] = False + print("[Mode] Plan mode disabled") + + def enable_subagent(self): + """Enable subagent delegation and persist the setting.""" + client = self.client + if client is None: + return + client._subagent_enabled = True + self._runtime_settings["subagent_enabled"] = True + print("[Mode] Subagent delegation enabled") + + def disable_subagent(self): + """Disable subagent delegation and persist the setting.""" + client = self.client + if client is None: + return + client._subagent_enabled = False + self._runtime_settings["subagent_enabled"] = False + print("[Mode] Subagent delegation disabled") + + def set_recursion_limit(self, limit: int): + """ + Set the recursion limit for agent execution and persist the setting. + + Args: + limit: Maximum number of graph steps LangGraph will execute. + + Returns: + bool: True if limit was set successfully. 
+ """ + if not isinstance(limit, int) or limit < 1: + print("[Error] Recursion limit must be a positive integer") + return False + + self._runtime_settings["recursion_limit"] = limit + print(f"[Config] Recursion limit set to: {limit}") + return True + + # ------------------------------------------------------------------ + # Diagnostics + # ------------------------------------------------------------------ + + def diagnose_tool_calls(self, session_id=None): + """ + Analyze tool call patterns from checkpoint history to identify loops. + + Args: + session_id: Session ID, uses current session if None. + """ + session_id = session_id or self.current_session_id + if not session_id: + print("[Error] No active session") + return + + from collections import Counter + + client = self._get_or_create_client(session_id) + thread_data = client.get_thread(session_id) + checkpoints = thread_data.get("checkpoints", []) + + if not checkpoints: + print("[Diagnostics] No checkpoints found") + return + + print(f"\n[Tool Call Diagnostics | 工具调用诊断]") + print(f" Session: {session_id}") + print(f" Total checkpoints: {len(checkpoints)}") + print() + + # Collect tool calls - both WITH and WITHOUT deduplication for comparison + # Each checkpoint's messages contain full history, so we need to track seen IDs + + # 1. Raw collection (WITH duplicates - counts every occurrence across checkpoints) + raw_tool_calls: list[dict] = [] + for cp in checkpoints: + messages = cp["values"].get("messages", []) + for msg in messages: + if msg.get("type") == "ai" and msg.get("tool_calls"): + for tc in msg["tool_calls"]: + raw_tool_calls.append({ + "name": tc.get("name"), + "args": tc.get("args"), + "checkpoint_id": cp.get("checkpoint_id"), + "tool_call_id": tc.get("id"), + }) + + # 2. 
Deduplicated collection (unique tool calls only) + seen_message_ids: set[str] = set() + unique_tool_calls: list[dict] = [] + for cp in checkpoints: + messages = cp["values"].get("messages", []) + for msg in messages: + msg_id = msg.get("id") + # Generate fallback id if message has no id + if msg_id is None: + msg_id = f"__no_id__:{msg.get('type', '')}:{str(msg.get('content', ''))[:100]}" + # Only process new messages + if msg_id not in seen_message_ids: + seen_message_ids.add(msg_id) + if msg.get("type") == "ai" and msg.get("tool_calls"): + for tc in msg["tool_calls"]: + unique_tool_calls.append({ + "name": tc.get("name"), + "args": tc.get("args"), + "checkpoint_id": cp.get("checkpoint_id"), + "tool_call_id": tc.get("id"), + }) + + if not unique_tool_calls: + print(" No tool calls found in history") + return + + # Comparison: Raw (with duplicates) vs Unique (deduplicated) + raw_counts = Counter(tc["name"] for tc in raw_tool_calls) + unique_counts = Counter(tc["name"] for tc in unique_tool_calls) + + print(f"[Tool Call Frequency Comparison | 工具调用频率对比]") + print(f" {'Tool Name':<40} | {'Unique':<8} | {'Raw (with dup)':<15} | {'Ratio':<8}") + print(f" {'-'*40}-+-{'-'*8}-+-{'-'*15}-+-{'-'*8}") + + all_tool_names = set(raw_counts.keys()) | set(unique_counts.keys()) + for tool_name in sorted(all_tool_names, key=lambda x: -unique_counts[x]): + unique_cnt = unique_counts.get(tool_name, 0) + raw_cnt = raw_counts.get(tool_name, 0) + ratio = raw_cnt / unique_cnt if unique_cnt > 0 else 0 + ratio_str = f"{ratio:.1f}x" + print(f" {tool_name:<40} | {unique_cnt:<8} | {raw_cnt:<15} | {ratio_str:<8}") + print() + + # Checkpoint density analysis + print(f"[Checkpoint Density Analysis | 检查点密度分析]") + print(f" Unique tool calls: {len(unique_tool_calls)}") + print(f" Raw occurrences across all checkpoints: {len(raw_tool_calls)}") + if len(checkpoints) > 0: + density = len(raw_tool_calls) / max(len(unique_tool_calls), 1) + print(f" Average duplications per unique call: {density:.1f}x") + 
class SessionStore:
    """
    Thread-safe asynchronous session storage manager.

    Session metadata is kept in two in-memory dicts (``sessions`` and
    ``session_metrics``) and persisted as one ``<session_id>.json`` file per
    session.  ``save_async`` only queues a snapshot; a background daemon
    thread performs the actual file write.

    Locking:
        ``_lock``     guards the ``_pending_writes`` dict only (held briefly).
        ``_io_lock``  serialises every operation that touches session files
                      (worker write, delete, archive) so that a delete or
                      archive can never interleave with an in-flight write.
        When both are needed, ``_io_lock`` is always taken first.
    """

    def __init__(self, sessions_dir: Path, archive_dir: Path):
        """
        Initialize the session store.

        Creates both directories if needed, starts the background writer
        thread and loads every readable ``*.json`` file from ``sessions_dir``.

        Args:
            sessions_dir: Directory to store active session files
            archive_dir: Directory to store archived session files
        """
        self.sessions_dir = sessions_dir
        self.archive_dir = archive_dir

        # Create directories if they don't exist
        self.sessions_dir.mkdir(parents=True, exist_ok=True)
        self.archive_dir.mkdir(parents=True, exist_ok=True)

        # In-memory session cache
        self.sessions = {}
        self.session_metrics = defaultdict(lambda: {"total_tokens": 0, "tool_calls": 0, "turns": 0})

        # Async write infrastructure
        self._write_queue = queue.Queue()
        self._pending_writes = {}
        self._lock = threading.Lock()
        self._io_lock = threading.Lock()
        self._write_thread = threading.Thread(target=self._write_worker, daemon=True)
        self._write_thread.start()

        # Load existing sessions from disk
        self._load_sessions_from_disk()

    @staticmethod
    def _write_json(target: Path, data) -> None:
        """Write ``data`` as JSON to ``target`` atomically.

        The payload goes to a sibling ``<name>.tmp`` file first and is then
        moved over ``target``, so a crash mid-write never leaves a truncated
        session file behind.  The ``.tmp`` suffix keeps the temporary file
        out of the ``*.json`` glob used when loading sessions.
        """
        tmp_file = target.with_name(target.name + ".tmp")
        try:
            with open(tmp_file, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2)
            tmp_file.replace(target)
        except Exception:
            tmp_file.unlink(missing_ok=True)
            raise

    @staticmethod
    def _report_write_failure(session_id, exc) -> None:
        """Log a failed background write instead of dropping it silently."""
        import logging  # function-scope import keeps the module header untouched

        logging.getLogger(__name__).warning("Failed to persist session %s: %s", session_id, exc)

    def _load_sessions_from_disk(self):
        """Load all active sessions from disk into memory on startup."""
        for session_file in self.sessions_dir.glob("*.json"):
            try:
                with open(session_file, "r", encoding="utf-8") as f:
                    data = json.load(f)
                session_id = data["session_id"]
                self.sessions[session_id] = data["info"]
                self.session_metrics[session_id] = data["metrics"]
            except Exception:
                # Deliberate best-effort: an unreadable or malformed file
                # must not prevent the remaining sessions from loading.
                continue

    def _write_worker(self):
        """Background worker thread that handles asynchronous file writes.

        Blocks on the queue until a session id arrives; ``None`` is the
        shutdown sentinel.  The pending payload is popped and written while
        holding ``_io_lock`` so delete/archive cannot run in between.
        """
        while True:
            session_id = self._write_queue.get()
            try:
                if session_id is None:
                    return  # ``finally`` still marks the sentinel as done

                with self._io_lock:
                    with self._lock:
                        data = self._pending_writes.pop(session_id, None)

                    # ``None`` means the write was coalesced into an earlier
                    # one, or cancelled by delete/archive.
                    if data is not None:
                        self._write_json(self.sessions_dir / f"{session_id}.json", data)
            except Exception as exc:
                # Keep the worker alive, but make the data loss visible.
                self._report_write_failure(session_id, exc)
            finally:
                self._write_queue.task_done()

    def save_async(self, session_id: str):
        """
        Queue a session for asynchronous saving to disk.

        Multiple rapid calls to save_async for the same session will be
        coalesced into a single write operation.  The payload is deep-copied
        at call time, so later mutations of the live session dicts cannot
        race with the worker's serialisation.

        Args:
            session_id: ID of the session to save
        """
        import copy  # function-scope import keeps the module header untouched

        if session_id not in self.sessions:
            return

        data = copy.deepcopy(
            {
                "session_id": session_id,
                "info": self.sessions[session_id],
                "metrics": self.session_metrics[session_id],
            }
        )

        with self._lock:
            self._pending_writes[session_id] = data
        self._write_queue.put(session_id)

    def delete_session_files(self, session_id: str):
        """
        Delete a session and all associated files.

        Any queued write for the session is cancelled first, and the file is
        removed under ``_io_lock`` so an in-flight worker write cannot
        re-create it afterwards.

        Args:
            session_id: ID of the session to delete
        """
        with self._io_lock:
            with self._lock:
                self._pending_writes.pop(session_id, None)

            (self.sessions_dir / f"{session_id}.json").unlink(missing_ok=True)

        # Remove from in-memory cache
        self.sessions.pop(session_id, None)
        self.session_metrics.pop(session_id, None)

    def archive_session_files(self, session_id: str):
        """
        Move a session from active to archived status.

        If a queued write has not been flushed yet, that payload is the
        newest state and is written straight to the archive (any older
        active file is removed).  Otherwise the active file, if present, is
        moved into the archive directory, replacing an existing archive file
        of the same name.

        Args:
            session_id: ID of the session to archive
        """
        session_file = self.sessions_dir / f"{session_id}.json"
        archive_file = self.archive_dir / f"{session_id}.json"

        with self._io_lock:
            with self._lock:
                pending_data = self._pending_writes.pop(session_id, None)

            if pending_data is not None:
                # Pending data supersedes whatever is on disk.
                self._write_json(archive_file, pending_data)
                session_file.unlink(missing_ok=True)
            elif session_file.exists():
                # replace() overwrites an existing target on every platform,
                # whereas rename() fails on Windows if the target exists.
                session_file.replace(archive_file)

        # Remove from in-memory cache
        self.sessions.pop(session_id, None)
        self.session_metrics.pop(session_id, None)

    def shutdown(self):
        """Gracefully shut down the session store, flushing all pending writes.

        Safe to call more than once: after the worker has exited, further
        calls return immediately instead of queueing a sentinel nobody will
        consume (which would make a later ``queue.join()`` hang).
        """
        if not self._write_thread.is_alive():
            return

        # Wait for all pending writes to complete
        self._write_queue.join()
        # Send shutdown signal to worker thread
        self._write_queue.put(None)
        # Wait for worker thread to exit
        self._write_thread.join(timeout=5)
+# --------------------------------------------------------------------------- + + +def _make_deerflow_client(*args, **kwargs): + client = MagicMock() + client.stream.return_value = iter([]) + client.get_thread.return_value = {"checkpoints": []} + client.list_models.return_value = {"models": []} + client.list_skills.return_value = {"skills": []} + client.upload_files.return_value = {"message": "ok"} + client.list_uploads.return_value = {"count": 0, "files": []} + client.delete_upload.return_value = {"message": "deleted"} + client.get_memory.return_value = {"facts": []} + client.clear_memory.return_value = None + client.update_skill.return_value = None + return client + + +_mock_client = MagicMock() +_mock_client.side_effect = _make_deerflow_client + +sys.modules["deerflow.client"] = MagicMock(DeerFlowClient=_mock_client) + +# --------------------------------------------------------------------------- +# Pre-mock SqliteSaver — each from_conn_string call returns a fresh context +# manager so per-session checkpointer identity works correctly. 
+# --------------------------------------------------------------------------- + + +def _make_sqlite_cm(conn_string=None): + cm = MagicMock() + cm.__enter__.return_value = MagicMock() + return cm + + +_mock_sqlite_saver = MagicMock() +_mock_sqlite_saver.from_conn_string.side_effect = _make_sqlite_cm + +sys.modules["langgraph.checkpoint.sqlite"] = MagicMock(SqliteSaver=_mock_sqlite_saver) +sys.modules["langgraph.checkpoint"] = MagicMock() +sys.modules["langgraph"] = MagicMock() diff --git a/cli/tests/test_cli.py b/cli/tests/test_cli.py new file mode 100644 index 0000000000..c6a3f25a24 --- /dev/null +++ b/cli/tests/test_cli.py @@ -0,0 +1,481 @@ +"""Integration tests for cli.py — command parsing and user interaction.""" + +from __future__ import annotations + +import io +import sys +from unittest.mock import MagicMock, patch + +import pytest + + +# --------------------------------------------------------------------------- +# safe_input +# --------------------------------------------------------------------------- + +class TestSafeInput: + """Input handling with UTF-8 encoding recovery.""" + + def test_returns_stripped_input(self): + from cli.cli import safe_input + + with patch("cli.cli.input", return_value=" hello \n"): + result = safe_input("> ") + assert result == " hello " + + def test_handles_unicode_decode_error(self): + from cli.cli import safe_input + + call_count = [0] + + def broken_input(prompt): + call_count[0] += 1 + if call_count[0] == 1: + raise UnicodeDecodeError("utf-8", b"", 0, 1, "mock error") + return "recovered" + + with patch("cli.cli.input", side_effect=broken_input): + result = safe_input("> ") + assert result == "recovered" + + def test_handles_eof(self): + from cli.cli import safe_input + + with patch("cli.cli.input", side_effect=EOFError()): + result = safe_input("> ") + assert result == "" + + +# --------------------------------------------------------------------------- +# multi_line_input +# 
class TestMultiLineInput:
    """Behaviour of multi_line_input, which collects lines up to the !end sentinel."""

    def test_reads_until_end_sentinel(self):
        """Lines before the sentinel are joined with newlines; the sentinel is dropped."""
        from cli.cli import multi_line_input

        scripted = iter(["line1", "line2", "!end"])

        def feed():
            return next(scripted)

        with patch("cli.cli.input", side_effect=feed):
            collected = multi_line_input("Enter:")

        assert collected == "line1\nline2"

    def test_handles_eof(self):
        """EOF on the first read produces an empty result."""
        from cli.cli import multi_line_input

        with patch("cli.cli.input", side_effect=EOFError()):
            collected = multi_line_input("Enter:")

        assert collected == ""

    def test_handles_unicode_decode_error(self):
        """A decode error on the first read, then the sentinel, yields an empty result."""
        from cli.cli import multi_line_input

        attempts = 0

        def flaky_input():
            nonlocal attempts
            attempts += 1
            if attempts == 1:
                raise UnicodeDecodeError("utf-8", b"", 0, 1, "mock error")
            return "!end"

        with patch("cli.cli.input", side_effect=flaky_input):
            collected = multi_line_input("Enter:")

        assert collected == ""
[{"name": "opus", "display_name": "Opus", "supports_thinking": True}]} + mock_engine.client.list_skills.return_value = {"skills": [{"name": "coding", "category": "dev", "enabled": True}]} + mock_engine.client.get_memory.return_value = {"facts": []} + mock_engine.list_uploads.return_value = {"count": 0, "files": []} + + input_iter = iter(inputs) + + def mock_input(prompt=""): + try: + return next(input_iter) + except StopIteration: + raise KeyboardInterrupt + + with patch("cli.cli.DeerFlowProductionEngine", return_value=mock_engine), \ + patch("cli.cli.safe_input", side_effect=mock_input), \ + patch("cli.cli.multi_line_input", return_value="multi line content"): + try: + from cli.cli import main + main() + except (KeyboardInterrupt, SystemExit): + pass + + # --- Session management --- + + def test_new_creates_session(self, monkeypatch): + engine = MagicMock() + engine.current_session_id = "test1234" + self._run_main(["!new custom-id My Title", "!exit"], engine, monkeypatch) + engine.create_session.assert_called_with("custom-id", "My Title") + + def test_new_without_args(self, monkeypatch): + engine = MagicMock() + engine.current_session_id = "test1234" + self._run_main(["!new", "!exit"], engine, monkeypatch) + engine.create_session.assert_called_with(None, None) + + def test_switch_session(self, monkeypatch): + engine = MagicMock() + engine.current_session_id = "test1234" + self._run_main(["!switch other-session", "!exit"], engine, monkeypatch) + engine.switch_session.assert_called_with("other-session") + + def test_switch_missing_arg_shows_error(self, monkeypatch): + engine = MagicMock() + engine.current_session_id = "test1234" + self._run_main(["!switch", "!exit"], engine, monkeypatch) + engine.switch_session.assert_not_called() + + def test_delete_session(self, monkeypatch): + engine = MagicMock() + engine.current_session_id = "test1234" + self._run_main(["!delete session sid123", "!exit"], engine, monkeypatch) + engine.delete_session.assert_called_with("sid123") 
+ + def test_rename_session(self, monkeypatch): + engine = MagicMock() + engine.current_session_id = "test1234" + self._run_main(["!rename New Title", "!exit"], engine, monkeypatch) + engine.rename_session.assert_called_with("test1234", "New Title") + + def test_rename_missing_title(self, monkeypatch): + engine = MagicMock() + engine.current_session_id = "test1234" + self._run_main(["!rename", "!exit"], engine, monkeypatch) + engine.rename_session.assert_not_called() + + def test_archive_session(self, monkeypatch): + engine = MagicMock() + engine.current_session_id = "test1234" + self._run_main(["!archive sid123", "!exit"], engine, monkeypatch) + engine.archive_session.assert_called_with("sid123") + + def test_list_archives(self, monkeypatch): + engine = MagicMock() + engine.current_session_id = "test1234" + self._run_main(["!archives", "!exit"], engine, monkeypatch) + engine.list_archives.assert_called_once() + + def test_restore_archive(self, monkeypatch): + engine = MagicMock() + engine.current_session_id = "test1234" + self._run_main(["!restore sid123", "!exit"], engine, monkeypatch) + engine.restore_archive.assert_called_with("sid123") + + def test_list_sessions(self, monkeypatch): + engine = MagicMock() + engine.current_session_id = "test1234" + self._run_main(["!sessions", "!exit"], engine, monkeypatch) + engine.list_sessions.assert_called_once() + + # --- Export --- + + def test_export_session(self, monkeypatch): + engine = MagicMock() + engine.current_session_id = "test1234" + self._run_main(["!export", "!exit"], engine, monkeypatch) + engine.export_session_markdown.assert_called_once() + + def test_export_all_checkpoints(self, monkeypatch): + engine = MagicMock() + engine.current_session_id = "test1234" + self._run_main(["!export_all", "!exit"], engine, monkeypatch) + engine.export_all_checkpoints.assert_called_once() + + # --- Search --- + + def test_search(self, monkeypatch): + engine = MagicMock() + engine.current_session_id = "test1234" + 
self._run_main(["!search keyword", "!exit"], engine, monkeypatch) + engine.search_sessions.assert_called_with("keyword") + + def test_search_missing_keyword(self, monkeypatch): + engine = MagicMock() + engine.current_session_id = "test1234" + self._run_main(["!search", "!exit"], engine, monkeypatch) + engine.search_sessions.assert_not_called() + + # --- Debugging --- + + def test_steps(self, monkeypatch): + engine = MagicMock() + engine.current_session_id = "test1234" + engine.get_session_steps.return_value = [ + {"step": 1, "user_input": "Hello world this is a longer message for truncation test"}, + ] + self._run_main(["!steps", "!exit"], engine, monkeypatch) + engine.get_session_steps.assert_called_once() + + def test_steps_all(self, monkeypatch): + engine = MagicMock() + engine.current_session_id = "test1234" + engine.get_all_checkpoint_steps.return_value = [ + {"checkpoint_id": "abc12345xx", "ts": "2024-01-01", "has_new_content": True}, + ] + self._run_main(["!steps_all", "!exit"], engine, monkeypatch) + engine.get_all_checkpoint_steps.assert_called_once() + + # --- File management --- + + def test_upload_file(self, monkeypatch): + engine = MagicMock() + engine.current_session_id = "test1234" + self._run_main(["!upload /path/to/file.txt", "!exit"], engine, monkeypatch) + engine.upload_file.assert_called_with("/path/to/file.txt") + + def test_list_files(self, monkeypatch): + engine = MagicMock() + engine.current_session_id = "test1234" + self._run_main(["!files", "!exit"], engine, monkeypatch) + engine.list_uploads.assert_called_once() + + def test_delete_file(self, monkeypatch): + engine = MagicMock() + engine.current_session_id = "test1234" + self._run_main(["!delete myfile.txt", "!exit"], engine, monkeypatch) + engine.delete_upload.assert_called_with("myfile.txt") + + # --- Models & skills --- + + def test_list_models(self, monkeypatch): + engine = MagicMock() + engine.current_session_id = "test1234" + self._run_main(["!models", "!exit"], engine, monkeypatch) 
+ engine.client.list_models.assert_called_once() + + def test_use_model(self, monkeypatch): + engine = MagicMock() + engine.current_session_id = "test1234" + self._run_main(["!use sonnet", "!exit"], engine, monkeypatch) + engine.switch_model.assert_called_with("sonnet") + + def test_list_skills(self, monkeypatch): + engine = MagicMock() + engine.current_session_id = "test1234" + self._run_main(["!skills", "!exit"], engine, monkeypatch) + engine.client.list_skills.assert_called_once() + + def test_enable_skill(self, monkeypatch): + engine = MagicMock() + engine.current_session_id = "test1234" + self._run_main(["!enable coding", "!exit"], engine, monkeypatch) + engine.enable_skill.assert_called_with("coding") + + def test_disable_skill(self, monkeypatch): + engine = MagicMock() + engine.current_session_id = "test1234" + self._run_main(["!disable coding", "!exit"], engine, monkeypatch) + engine.disable_skill.assert_called_with("coding") + + # --- Runtime modes --- + + def test_plan_on(self, monkeypatch): + engine = MagicMock() + engine.current_session_id = "test1234" + self._run_main(["!plan on", "!exit"], engine, monkeypatch) + engine.enable_plan_mode.assert_called_once() + + def test_plan_off(self, monkeypatch): + engine = MagicMock() + engine.current_session_id = "test1234" + self._run_main(["!plan off", "!exit"], engine, monkeypatch) + engine.disable_plan_mode.assert_called_once() + + def test_plan_invalid(self, monkeypatch): + engine = MagicMock() + engine.current_session_id = "test1234" + self._run_main(["!plan invalid", "!exit"], engine, monkeypatch) + engine.enable_plan_mode.assert_not_called() + engine.disable_plan_mode.assert_not_called() + + def test_subagent_on(self, monkeypatch): + engine = MagicMock() + engine.current_session_id = "test1234" + self._run_main(["!subagent on", "!exit"], engine, monkeypatch) + engine.enable_subagent.assert_called_once() + + def test_subagent_off(self, monkeypatch): + engine = MagicMock() + engine.current_session_id = 
"test1234" + self._run_main(["!subagent off", "!exit"], engine, monkeypatch) + engine.disable_subagent.assert_called_once() + + def test_subagent_invalid(self, monkeypatch): + engine = MagicMock() + engine.current_session_id = "test1234" + self._run_main(["!subagent invalid", "!exit"], engine, monkeypatch) + engine.enable_subagent.assert_not_called() + engine.disable_subagent.assert_not_called() + + # --- Memory --- + + def test_memory(self, monkeypatch): + engine = MagicMock() + engine.current_session_id = "test1234" + self._run_main(["!memory", "!exit"], engine, monkeypatch) + engine.client.get_memory.assert_called_once() + + def test_clear_memory(self, monkeypatch): + engine = MagicMock() + engine.current_session_id = "test1234" + self._run_main(["!clear", "!exit"], engine, monkeypatch) + engine.client.clear_memory.assert_called_once() + + # --- Help / Exit --- + + def test_help(self, monkeypatch): + engine = MagicMock() + engine.current_session_id = "test1234" + # help prints but does not delegate to any engine method + self._run_main(["!help", "!exit"], engine, monkeypatch) + # No engine methods should be called (just printing) + + def test_exit_breaks_loop(self, monkeypatch): + engine = MagicMock() + engine.current_session_id = "test1234" + self._run_main(["!exit"], engine, monkeypatch) + # !exit causes main() to break out of the loop — no exception + + # --- Multi-line --- + + def test_multi_line_mode(self, monkeypatch): + engine = MagicMock() + engine.current_session_id = "test1234" + engine.chat.return_value = iter(["response"]) + + self._run_main(["!multi", "!exit"], engine, monkeypatch) + engine.chat.assert_called_with("multi line content") + + def test_multi_line_empty_ignored(self, monkeypatch): + engine = MagicMock() + engine.current_session_id = "test1234" + + with patch("cli.cli.DeerFlowProductionEngine", return_value=engine), \ + patch("cli.cli.safe_input", side_effect=["!multi", KeyboardInterrupt]), \ + patch("cli.cli.multi_line_input", 
return_value=""): + try: + from cli.cli import main + main() + except (KeyboardInterrupt, SystemExit): + pass + # Empty multi-line input should not call chat + engine.chat.assert_not_called() + + # --- Normal chat --- + + def test_default_chat_path(self, monkeypatch): + engine = MagicMock() + engine.current_session_id = "test1234" + engine.chat.return_value = iter(["Hello!"]) + + self._run_main(["What is AI?", "!exit"], engine, monkeypatch) + engine.chat.assert_called_with("What is AI?") + + def test_creates_session_when_none_active(self, monkeypatch): + engine = MagicMock() + engine.current_session_id = None + engine.client = MagicMock() + engine.client._model_name = "opus" + engine.client.list_models.return_value = {"models": [{"name": "opus", "display_name": "Opus", "supports_thinking": True}]} + engine.client.list_skills.return_value = {"skills": []} + engine.client.get_memory.return_value = {"facts": []} + engine.list_uploads.return_value = {"count": 0, "files": []} + engine.chat.return_value = iter(["response"]) + + # create_session must set current_session_id so the loop doesn't crash + def _create(*args, **kwargs): + engine.current_session_id = "new12345" + return "new12345" + engine.create_session.side_effect = _create + + with patch("cli.cli.DeerFlowProductionEngine", return_value=engine), \ + patch("cli.cli.safe_input", side_effect=["Hello", "!exit"]): + try: + from cli.cli import main + main() + except (KeyboardInterrupt, SystemExit): + pass + + engine.create_session.assert_called_once() + + # --- Error handling --- + + def test_generic_exception_handling(self, monkeypatch): + """Exceptions in command handling are caught and printed, not propagated.""" + engine = MagicMock() + engine.current_session_id = "test1234" + engine.chat.side_effect = RuntimeError("Something went wrong") + + self._run_main(["broken", "!exit"], engine, monkeypatch) + # Exception is caught in the REPL loop — main() continues to !exit + + +# 
class TestMainNullSession:
    """main() must create a session when the engine reports no current session."""

    @pytest.fixture(autouse=True)
    def _reset_singleton(self):
        """Clear the engine singleton flags both before and after each test."""
        from engine import DeerFlowProductionEngine

        def _wipe():
            DeerFlowProductionEngine._instance = None
            DeerFlowProductionEngine._initialized = False

        _wipe()
        yield
        _wipe()

    def test_null_session_triggers_create(self, monkeypatch):
        """When current_session_id is None, main() calls create_session()."""
        engine = MagicMock()
        engine.current_session_id = None

        # The fake create_session has to populate current_session_id,
        # otherwise the REPL loop would keep seeing a null session.
        def _create(*args, **kwargs):
            new_id = "new12345"
            engine.current_session_id = new_id
            return new_id

        engine.create_session.side_effect = _create

        scripted_inputs = ["!sessions", KeyboardInterrupt]
        with patch("cli.cli.DeerFlowProductionEngine", return_value=engine):
            with patch("cli.cli.safe_input", side_effect=scripted_inputs):
                try:
                    from cli.cli import main
                    main()
                except (KeyboardInterrupt, SystemExit):
                    pass

        engine.create_session.assert_called()
def _make_engine(mock_store, tmp_path: Path) -> DeerFlowProductionEngine:
    """Build a fresh engine backed by ``mock_store`` and temp directories.

    ``engine.SessionStore`` is patched to return ``mock_store``, and
    ``engine.SESSIONS_DIR`` / ``engine.ARCHIVE_DIR`` are redirected below
    ``tmp_path``.  The patchers are started but not stopped here: they are
    stored on the engine as ``_patchers`` so the autouse fixture
    ``_reset_engine_singleton`` can stop them on teardown.
    """
    # Drop any previous singleton so this call really constructs a new engine.
    DeerFlowProductionEngine._instance = None
    DeerFlowProductionEngine._initialized = False

    # Give the store real dicts for state and plain mocks for persistence calls.
    mock_store.sessions = {}
    mock_store.session_metrics = {}
    for method_name in ("save_async", "delete_session_files", "archive_session_files", "shutdown"):
        setattr(mock_store, method_name, MagicMock())

    active_patchers = [
        patch("engine.SessionStore", return_value=mock_store),
        patch("engine.SESSIONS_DIR", tmp_path / "sessions"),
        patch("engine.ARCHIVE_DIR", tmp_path / "archive"),
    ]
    for patcher in active_patchers:
        patcher.start()

    built = DeerFlowProductionEngine()
    built._patchers = active_patchers
    return built
+ "title": title, + "last_checkpoint_id": None, + } + engine.store.session_metrics[sid] = {"total_tokens": 0, "tool_calls": 0, "turns": 0} + engine.current_session_id = sid + + +# --------------------------------------------------------------------------- +# Singleton +# --------------------------------------------------------------------------- + +class TestSingleton: + """DeerFlowProductionEngine must behave as a singleton.""" + + def test_same_instance_returned(self): + a = DeerFlowProductionEngine() + b = DeerFlowProductionEngine() + assert a is b + + def test_init_guards_against_reinit(self): + engine = DeerFlowProductionEngine() + original_store = engine.store + # Calling __init__ again must not overwrite store + engine.__init__() + assert engine.store is original_store + + +# --------------------------------------------------------------------------- +# _get_or_create_client — per-session client setup +# --------------------------------------------------------------------------- + +class TestGetOrCreateClient: + """Client creation and reuse for session isolation.""" + + def test_creates_new_client_for_unknown_session(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _prime_session(engine, "s1") + + client = engine._get_or_create_client("s1") + + assert "s1" in engine._clients + assert engine._clients["s1"] is client + + def test_reuses_existing_client(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _prime_session(engine, "s1") + + c1 = engine._get_or_create_client("s1") + c2 = engine._get_or_create_client("s1") + + assert c1 is c2 + + def test_each_session_gets_own_client(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _prime_session(engine, "s1") + _prime_session(engine, "s2") + + c1 = engine._get_or_create_client("s1") + c2 = engine._get_or_create_client("s2") + + assert c1 
is not c2 + assert "s1" in engine._clients + assert "s2" in engine._clients + + def test_applies_runtime_settings_to_new_client(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _prime_session(engine, "s1") + engine._runtime_settings["model_name"] = "opus" + engine._runtime_settings["plan_mode"] = True + engine._runtime_settings["thinking_enabled"] = False + + client = engine._get_or_create_client("s1") + + assert client._model_name == "opus" + assert client._plan_mode is True + assert client._thinking_enabled is False + + def test_does_not_reapply_settings_to_existing_client(self, tmp_path: Path): + """Settings are only applied on first creation, not on reuse.""" + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _prime_session(engine, "s1") + + engine._runtime_settings["model_name"] = "opus" + c1 = engine._get_or_create_client("s1") + + engine._runtime_settings["model_name"] = "sonnet" + c2 = engine._get_or_create_client("s1") + + assert c1 is c2 + assert c1._model_name == "opus" # unchanged from first creation + + +# --------------------------------------------------------------------------- +# client property +# --------------------------------------------------------------------------- + +class TestClientProperty: + """The client property returns the current session's client.""" + + def test_returns_none_when_no_current_session(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _clear_all_sessions(engine) + assert engine.client is None + + def test_returns_client_for_current_session(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _prime_session(engine, "s1") + engine._get_or_create_client("s1") + assert engine.client is engine._clients["s1"] + + def test_returns_none_when_session_has_no_client(self, tmp_path: Path): + store = MagicMock() + 
store.sessions = {} + engine = _make_engine(store, tmp_path) + engine.current_session_id = "orphan" + assert engine.client is None + + +# --------------------------------------------------------------------------- +# Session lifecycle +# --------------------------------------------------------------------------- + +class TestSessionLifecycle: + """CRUD operations on sessions.""" + + def test_create_session_assigns_uuid(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + sid = engine.create_session() + assert len(sid) == 32 # uuid4 hex + assert sid in engine.store.sessions + + def test_create_session_with_custom_id(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + sid = engine.create_session(session_id="my-session-42", title="Custom") + assert sid == "my-session-42" + assert engine.store.sessions["my-session-42"]["title"] == "Custom" + + def test_create_session_rejects_invalid_id(self, tmp_path: Path): + """Non-alphanumeric-underscore-dash IDs are replaced with uuid.""" + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + sid = engine.create_session(session_id="bad id!") + assert sid != "bad id!" 
+ assert len(sid) == 32 + + def test_create_session_duplicate_id_returns_same(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + sid = engine.create_session(session_id="dup", title="First") + sid2 = engine.create_session(session_id="dup", title="Second") + assert sid == sid2 + # Title is not overwritten + assert engine.store.sessions["dup"]["title"] == "First" + + def test_switch_session_success(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _prime_session(engine, "s1") + _prime_session(engine, "s2") + + result = engine.switch_session("s2") + + assert result is True + assert engine.current_session_id == "s2" + + def test_switch_session_not_found(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _prime_session(engine, "s1") + + result = engine.switch_session("nonexistent") + + assert result is False + assert engine.current_session_id == "s1" # unchanged + + def test_switch_session_updates_last_active(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _prime_session(engine, "s1") + _prime_session(engine, "s2") + old_active = engine.store.sessions["s2"]["last_active"] + + engine.switch_session("s2") + + assert engine.store.sessions["s2"]["last_active"] > old_active + + def test_delete_session_removes_everything(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _clear_all_sessions(engine) + _prime_session(engine, "s1") + + # Create a client for the session + engine._get_or_create_client("s1") + + # Make the mock actually remove from sessions dict + def _delete(sid): + engine.store.sessions.pop(sid, None) + store.delete_session_files.side_effect = _delete + + engine.delete_session("s1") + + assert "s1" not in engine.store.sessions + # _ensure_current_session creates a new 
default after deletion + assert engine.current_session_id is not None + assert engine.current_session_id != "s1" + store.delete_session_files.assert_called_once_with("s1") + + def test_delete_session_not_found(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _prime_session(engine, "s1") + + result = engine.delete_session("nonexistent") + + assert result is False + assert "s1" in engine.store.sessions # other sessions untouched + + def test_rename_session_success(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _prime_session(engine, "s1", "Old") + + result = engine.rename_session("s1", "New Title") + + assert result is True + assert engine.store.sessions["s1"]["title"] == "New Title" + store.save_async.assert_called_with("s1") + + def test_rename_session_not_found(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + result = engine.rename_session("ghost", "Nope") + assert result is False + + def test_archive_session_moves_files(self, tmp_path: Path): + sessions_dir = tmp_path / "sessions" + sessions_dir.mkdir(parents=True, exist_ok=True) + archive_dir = tmp_path / "archive" + archive_dir.mkdir(parents=True, exist_ok=True) + + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _prime_session(engine, "s1") + + # Create a checkpoint db file to archive + db_path = sessions_dir / "s1_checkpoints.db" + db_path.write_text("fake-db") + engine._get_or_create_client("s1") + + engine.archive_session("s1") + + store.archive_session_files.assert_called_once_with("s1") + # DB should have been moved + assert not (sessions_dir / "s1_checkpoints.db").exists() + assert (archive_dir / "s1_checkpoints.db").exists() + + def test_restore_archive_success(self, tmp_path: Path): + sessions_dir = tmp_path / "sessions" + sessions_dir.mkdir(parents=True, exist_ok=True) + archive_dir 
= tmp_path / "archive" + archive_dir.mkdir(parents=True, exist_ok=True) + + # Write archive data + archive_data = { + "session_id": "arch1", + "info": { + "created_at": 1000.0, + "last_active": 2000.0, + "title": "Archived Session", + "last_checkpoint_id": None, + }, + "metrics": {"total_tokens": 10, "tool_calls": 1, "turns": 1}, + } + (archive_dir / "arch1.json").write_text(json.dumps(archive_data)) + + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + + result = engine.restore_archive("arch1") + + assert result is True + assert "arch1" in engine.store.sessions + assert engine.store.sessions["arch1"]["title"] == "Archived Session" + assert engine.current_session_id == "arch1" + + def test_restore_archive_not_found(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + + result = engine.restore_archive("missing") + + assert result is False + + def test_restore_archive_already_active(self, tmp_path: Path): + sessions_dir = tmp_path / "sessions" + sessions_dir.mkdir(parents=True, exist_ok=True) + archive_dir = tmp_path / "archive" + archive_dir.mkdir(parents=True, exist_ok=True) + (archive_dir / "arch1.json").write_text( + json.dumps({ + "session_id": "arch1", + "info": {"created_at": 1.0, "last_active": 2.0, "title": "X", "last_checkpoint_id": None}, + "metrics": {"total_tokens": 0, "tool_calls": 0, "turns": 0}, + }) + ) + + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _prime_session(engine, "arch1") # already active + + result = engine.restore_archive("arch1") + + assert result is False + + +# --------------------------------------------------------------------------- +# _ensure_current_session +# --------------------------------------------------------------------------- + +class TestEnsureCurrentSession: + """Automatic recovery when current_session_id becomes invalid.""" + + def test_falls_back_to_first_session(self, tmp_path: Path): + 
store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _clear_all_sessions(engine) + _prime_session(engine, "s1") + _prime_session(engine, "s2") + engine.current_session_id = "orphan" # not in store + + engine._ensure_current_session() + + assert engine.current_session_id == "s1" + + def test_creates_default_when_store_empty(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _clear_all_sessions(engine) + engine.current_session_id = None + + engine._ensure_current_session() + + assert engine.current_session_id is not None + assert engine.current_session_id in engine.store.sessions + + +# --------------------------------------------------------------------------- +# _extract_steps — checkpoint-to-step parsing +# --------------------------------------------------------------------------- + +class TestExtractSteps: + """Parsing checkpoint history into structured conversation steps.""" + + def _make_thread_data(self, checkpoints: list[dict]) -> dict: + return {"checkpoints": checkpoints} + + def _make_cp(self, messages: list[dict], checkpoint_id="cp1", ts="2024-01-01"): + return { + "checkpoint_id": checkpoint_id, + "parent_checkpoint_id": "parent1", + "ts": ts, + "values": {"messages": messages, "total_tokens": 100}, + } + + def test_empty_checkpoints(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _prime_session(engine, "s1") + + client = engine._get_or_create_client("s1") + client.get_thread.return_value = self._make_thread_data([]) + + steps = engine._extract_steps("s1") + assert steps == [] + + def test_single_human_ai_turn(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _prime_session(engine, "s1") + + client = engine._get_or_create_client("s1") + client.get_thread.return_value = self._make_thread_data([ + self._make_cp([ + {"type": "human", "id": "h1", 
"content": "Hello", "metadata": {}}, + {"type": "ai", "id": "a1", "content": "Hi there!", "response_metadata": {"model": "opus"}}, + ]), + ]) + + steps = engine._extract_steps("s1") + + assert len(steps) == 1 + assert steps[0]["step"] == 1 + assert steps[0]["user_input"] == "Hello" + assert steps[0]["ai_response"] == "Hi there!" + assert steps[0]["ai_response_metadata"]["model"] == "opus" + + def test_detects_duplicate_messages_across_checkpoints(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _prime_session(engine, "s1") + + client = engine._get_or_create_client("s1") + client.get_thread.return_value = self._make_thread_data([ + self._make_cp([ + {"type": "human", "id": "h1", "content": "Q1", "metadata": {}}, + {"type": "ai", "id": "a1", "content": "A1", "response_metadata": {}}, + ], checkpoint_id="cp1"), + # cp2 has the same ai message again (duplicate) + self._make_cp([ + {"type": "human", "id": "h1", "content": "Q1", "metadata": {}}, + {"type": "ai", "id": "a1", "content": "A1", "response_metadata": {}}, + ], checkpoint_id="cp2"), + ]) + + steps = engine._extract_steps("s1") + + assert len(steps) == 1 # only one logical step + assert len(steps[0]["duplicate_messages"]) == 2 # h1 and a1 were dupes + + def test_tool_calls_and_results(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _prime_session(engine, "s1") + + client = engine._get_or_create_client("s1") + client.get_thread.return_value = self._make_thread_data([ + self._make_cp([ + {"type": "human", "id": "h1", "content": "Search X", "metadata": {}}, + { + "type": "ai", + "id": "a1", + "content": "", + "response_metadata": {}, + "tool_calls": [ + {"id": "tc1", "name": "search", "args": {"query": "X"}}, + ], + }, + { + "type": "tool", + "id": "t1", + "content": "Found 3 results", + "tool_call_id": "tc1", + }, + ]), + ]) + + steps = engine._extract_steps("s1") + + assert len(steps) == 1 + 
assert len(steps[0]["tool_calls"]) == 1 + assert steps[0]["tool_calls"][0]["name"] == "search" + assert steps[0]["tool_calls"][0]["result"] == "Found 3 results" + + def test_messages_without_id(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _prime_session(engine, "s1") + + client = engine._get_or_create_client("s1") + client.get_thread.return_value = self._make_thread_data([ + self._make_cp([ + {"type": "human", "content": "No ID message", "metadata": {}}, + ]), + ]) + + steps = engine._extract_steps("s1") + + assert len(steps) == 1 + assert steps[0]["user_input"] == "No ID message" + + def test_marks_duplicate_tool_calls(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _prime_session(engine, "s1") + + client = engine._get_or_create_client("s1") + client.get_thread.return_value = self._make_thread_data([ + self._make_cp([ + {"type": "human", "id": "h1", "content": "Q", "metadata": {}}, + {"type": "ai", "id": "a1", "content": "", "response_metadata": {}, "tool_calls": [ + {"id": "tc1", "name": "t1", "args": {}}, + ]}, + ]), + self._make_cp([ + {"type": "human", "id": "h2", "content": "Q2", "metadata": {}}, + {"type": "ai", "id": "a2", "content": "", "response_metadata": {}, "tool_calls": [ + {"id": "tc1", "name": "t1", "args": {}}, # same TC id = duplicate + ]}, + ]), + ]) + + steps = engine._extract_steps("s1") + + assert len(steps) == 2 + # The second step's tool call should be marked as duplicate + assert steps[1]["tool_calls"][0]["is_duplicate"] is True + + +# --------------------------------------------------------------------------- +# get_session_steps / get_all_checkpoint_steps +# --------------------------------------------------------------------------- + +class TestIntrospectionMethods: + """Read-only introspection using per-session clients.""" + + def test_get_session_steps_defaults_to_current(self, tmp_path: Path): + store = 
MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _prime_session(engine, "s1") + with patch.object(engine, "_extract_steps", return_value=[{"step": 1}]) as mock_extract: + result = engine.get_session_steps() + mock_extract.assert_called_once_with("s1") + assert result == [{"step": 1}] + + def test_get_session_steps_returns_empty_when_no_session(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _clear_all_sessions(engine) + assert engine.get_session_steps() == [] + + def test_get_all_checkpoint_steps_basic(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _prime_session(engine, "s1") + + client = engine._get_or_create_client("s1") + client.get_thread.return_value = { + "checkpoints": [ + { + "checkpoint_id": "cp1", + "parent_checkpoint_id": None, + "ts": "2024-01-01", + "values": { + "messages": [ + {"type": "human", "id": "h1", "content": "Hello"}, + {"type": "ai", "id": "a1", "content": "Hi"}, + ], + }, + }, + { + "checkpoint_id": "cp2", + "parent_checkpoint_id": "cp1", + "ts": "2024-01-02", + "values": { + "messages": [ + {"type": "human", "id": "h1", "content": "Hello"}, # dupe + {"type": "ai", "id": "a1", "content": "Hi"}, # dupe + {"type": "human", "id": "h2", "content": "Follow-up"}, # new + ], + }, + }, + ], + } + + cps = engine.get_all_checkpoint_steps("s1") + + assert len(cps) == 2 + assert cps[0]["checkpoint_id"] == "cp1" + assert len(cps[0]["new_messages"]) == 2 + assert cps[1]["checkpoint_id"] == "cp2" + assert len(cps[1]["new_messages"]) == 1 + assert cps[1]["new_messages"][0]["content"] == "Follow-up" + assert cps[0]["has_new_content"] is True + + def test_get_all_checkpoint_steps_no_checkpoints(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _prime_session(engine, "s1") + + client = engine._get_or_create_client("s1") + 
client.get_thread.return_value = {"checkpoints": []} + + cps = engine.get_all_checkpoint_steps("s1") + assert cps == [] + + +# --------------------------------------------------------------------------- +# Chat +# --------------------------------------------------------------------------- + +class TestChat: + """Streaming chat and metrics tracking.""" + + def test_chat_creates_session_when_none_active(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _clear_all_sessions(engine) + # chat() calls create_session when store is empty and no current session + with patch.object(engine, "create_session", wraps=engine.create_session) as spy: + list(engine.chat("hello")) + spy.assert_called_once() + + def test_chat_streams_response_chunks(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _prime_session(engine, "s1") + + from unittest.mock import PropertyMock + from types import SimpleNamespace + + client = engine._get_or_create_client("s1") + + Event = SimpleNamespace + client.stream.return_value = iter([ + Event(type="messages-tuple", data={"type": "ai", "content": "Hello"}), + Event(type="messages-tuple", data={"type": "ai", "content": " world"}), + Event(type="end", data={"usage": {"total_tokens": 50}}), + ]) + client.get_thread.return_value = {"checkpoints": [{"checkpoint_id": "cp1"}]} + + chunks = list(engine.chat("Hi")) + + assert "Hello" in chunks + assert " world" in chunks + # Metrics line at the end + assert any("50" in c for c in chunks if isinstance(c, str)) + + def test_chat_increments_metrics(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _prime_session(engine, "s1") + engine.store.session_metrics["s1"]["turns"] = 0 + engine.store.session_metrics["s1"]["total_tokens"] = 0 + + from types import SimpleNamespace + + Event = SimpleNamespace + client = engine._get_or_create_client("s1") 
+ client.stream.return_value = iter([ + Event(type="messages-tuple", data={"type": "ai", "content": "A"}), + Event(type="messages-tuple", data={"type": "ai", "content": "B", "tool_calls": [{"id": "t1"}]}), + Event(type="end", data={"usage": {"total_tokens": 30}}), + ]) + client.get_thread.return_value = {"checkpoints": [{"checkpoint_id": "cp1"}]} + + list(engine.chat("Q")) + + assert engine.store.session_metrics["s1"]["turns"] == 1 + assert engine.store.session_metrics["s1"]["total_tokens"] == 30 + assert engine.store.session_metrics["s1"]["tool_calls"] == 1 + + def test_chat_updates_title_on_first_turn(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _prime_session(engine, "s1", "New Session") # default title + + from types import SimpleNamespace + + Event = SimpleNamespace + client = engine._get_or_create_client("s1") + client.stream.return_value = iter([ + Event(type="messages-tuple", data={"type": "ai", "content": "A long response about weather"}), + Event(type="end", data={"usage": {"total_tokens": 10}}), + ]) + client.get_thread.return_value = {"checkpoints": [{"checkpoint_id": "cp1"}]} + + list(engine.chat("What is the weather today?")) + + # Title should be truncated to first 30 chars of user message + assert engine.store.sessions["s1"]["title"] == "What is the weather today?" 
+ + def test_chat_does_not_overwrite_custom_title(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _prime_session(engine, "s1", "My Custom Title") + + from types import SimpleNamespace + + Event = SimpleNamespace + client = engine._get_or_create_client("s1") + client.stream.return_value = iter([ + Event(type="messages-tuple", data={"type": "ai", "content": "OK"}), + Event(type="end", data={"usage": {"total_tokens": 5}}), + ]) + client.get_thread.return_value = {"checkpoints": [{"checkpoint_id": "cp1"}]} + + list(engine.chat("Another message")) + + assert engine.store.sessions["s1"]["title"] == "My Custom Title" + + +# --------------------------------------------------------------------------- +# Runtime controls +# --------------------------------------------------------------------------- + +class TestRuntimeControls: + """Switching model, plan mode, subagent, and skills.""" + + def test_switch_model_success(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _prime_session(engine, "s1") + client = engine._get_or_create_client("s1") + client.list_models.return_value = {"models": [{"name": "opus"}, {"name": "sonnet"}]} + + result = engine.switch_model("opus") + + assert result is True + assert engine._runtime_settings["model_name"] == "opus" + + def test_switch_model_not_found(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _prime_session(engine, "s1") + client = engine._get_or_create_client("s1") + client.list_models.return_value = {"models": [{"name": "sonnet"}]} + + result = engine.switch_model("nonexistent") + + assert result is False + + def test_switch_model_no_client(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + assert engine.switch_model("opus") is False + + def test_enable_disable_plan_mode(self, tmp_path: Path): + 
store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _prime_session(engine, "s1") + client = engine._get_or_create_client("s1") + + engine.enable_plan_mode() + assert client._plan_mode is True + assert engine._runtime_settings["plan_mode"] is True + + engine.disable_plan_mode() + assert client._plan_mode is False + assert engine._runtime_settings["plan_mode"] is False + + def test_enable_disable_subagent(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _prime_session(engine, "s1") + client = engine._get_or_create_client("s1") + + engine.enable_subagent() + assert client._subagent_enabled is True + assert engine._runtime_settings["subagent_enabled"] is True + + engine.disable_subagent() + assert client._subagent_enabled is False + + def test_enable_disable_skill_no_client(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _clear_all_sessions(engine) + assert engine.enable_skill("s") is False + assert engine.disable_skill("s") is False + + def test_enable_skill_success(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _prime_session(engine, "s1") + engine._get_or_create_client("s1") + result = engine.enable_skill("coding") + assert result is True + + def test_runtime_settings_persisted_across_sessions(self, tmp_path: Path): + """Settings survive across session switches because _runtime_settings + is applied to each new client.""" + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _prime_session(engine, "s1") + engine._get_or_create_client("s1") + + engine._runtime_settings["model_name"] = "haiku" + engine._runtime_settings["plan_mode"] = True + + # Create a second session — it should inherit the settings + _prime_session(engine, "s2") + client2 = engine._get_or_create_client("s2") + + assert client2._model_name == "haiku" + 
assert client2._plan_mode is True + + +# --------------------------------------------------------------------------- +# Shutdown +# --------------------------------------------------------------------------- + +class TestShutdown: + """Graceful shutdown releases all resources.""" + + def test_shutdown_calls_store_shutdown(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + engine.shutdown() + store.shutdown.assert_called_once() + + def test_shutdown_destroys_all_clients(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _prime_session(engine, "s1") + _prime_session(engine, "s2") + engine._get_or_create_client("s1") + engine._get_or_create_client("s2") + + engine.shutdown() + + assert "s1" not in engine._clients + assert "s2" not in engine._clients + assert "s1" not in engine._checkpointer_cms + assert "s2" not in engine._checkpointer_cms + + +# --------------------------------------------------------------------------- +# Export +# --------------------------------------------------------------------------- + +class TestExport: + """Markdown export of sessions and checkpoints.""" + + def test_export_session_markdown_no_active_session(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _clear_all_sessions(engine) + result = engine.export_session_markdown() + assert result is None + + def test_export_session_markdown_creates_file(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _prime_session(engine, "s1", "Export Test") + engine.store.session_metrics["s1"]["total_tokens"] = 10 + + with patch.object(engine, "get_session_steps", return_value=[ + {"step": 1, "user_input": "Q", "ai_response": "A", "tool_calls": [], "duplicate_messages": []} + ]): + result = engine.export_session_markdown("s1") + + assert result is not None + assert 
os.path.exists(result) + content = Path(result).read_text() + assert "# Export Test" in content + assert "Q" in content + assert "A" in content + + def test_export_all_checkpoints_no_active_session(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _clear_all_sessions(engine) + result = engine.export_all_checkpoints() + assert result is None + + def test_export_all_checkpoints_creates_file(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _prime_session(engine, "s1", "CP Export") + engine.store.session_metrics["s1"]["total_tokens"] = 20 + + with patch.object(engine, "get_all_checkpoint_steps", return_value=[ + { + "checkpoint_id": "cp1", + "parent_checkpoint_id": None, + "ts": "2024-01-01", + "new_messages": [{"type": "human", "content": "Hello"}], + "has_new_content": True, + }, + ]): + result = engine.export_all_checkpoints("s1") + + assert result is not None + assert os.path.exists(result) + content = Path(result).read_text() + assert "# CP Export (All Checkpoints)" in content + assert "Hello" in content + + def test_export_session_markdown_with_tool_calls(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _prime_session(engine, "s1", "Tool Test") + + with patch.object(engine, "get_session_steps", return_value=[ + { + "step": 1, + "user_input": "Search X", + "ai_response": "Results:", + "tool_calls": [ + { + "id": "tc1", + "name": "search", + "args": {"query": "X"}, + "result": '{"hits": 5}', + "is_duplicate": False, + }, + ], + "duplicate_messages": [], + }, + ]): + result = engine.export_session_markdown("s1") + + content = Path(result).read_text() + assert "search" in content + assert "hits" in content + + def test_export_all_checkpoints_with_no_new_content(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + 
_prime_session(engine, "s1", "Empty CP") + + with patch.object(engine, "get_all_checkpoint_steps", return_value=[ + { + "checkpoint_id": "cp1", + "parent_checkpoint_id": None, + "ts": "2024-01-01", + "new_messages": [], + "has_new_content": False, + }, + ]): + result = engine.export_all_checkpoints("s1") + + content = Path(result).read_text() + assert "no new messages" in content.lower() + + +# --------------------------------------------------------------------------- +# Search +# --------------------------------------------------------------------------- + +class TestSearch: + """Keyword search across sessions.""" + + def test_search_finds_keyword_in_user_input(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _prime_session(engine, "s1", "Session One") + engine.store.sessions = {"s1": {"title": "Session One"}} + engine.store.session_metrics["s1"] = {"total_tokens": 0, "tool_calls": 0, "turns": 0} + + with patch.object(engine, "get_session_steps", return_value=[ + {"step": 1, "user_input": "I love Python programming", "ai_response": "That's great!"}, + ]): + engine.search_sessions("Python") + + # No exception -> test passes; search uses print() + + def test_search_finds_keyword_in_ai_response(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + engine.store.sessions = {"s1": {"title": "Session One"}} + engine.store.session_metrics["s1"] = {"total_tokens": 0, "tool_calls": 0, "turns": 0} + + with patch.object(engine, "get_session_steps", return_value=[ + {"step": 1, "user_input": "Hello", "ai_response": "I recommend using pytest"}, + ]): + engine.search_sessions("pytest") + + def test_search_no_match(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + engine.store.sessions = {"s1": {"title": "X"}} + engine.store.session_metrics["s1"] = {"total_tokens": 0, "tool_calls": 0, "turns": 0} + + with 
patch.object(engine, "get_session_steps", return_value=[ + {"step": 1, "user_input": "Hello", "ai_response": "Hi"}, + ]): + engine.search_sessions("zzznotfound") + + +# --------------------------------------------------------------------------- +# File upload operations +# --------------------------------------------------------------------------- + +class TestFileOperations: + """Upload, list, and delete files.""" + + def test_upload_file_no_active_session(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + result = engine.upload_file("/some/file") + assert result is None + + def test_upload_file_not_found(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _prime_session(engine, "s1") + result = engine.upload_file("/nonexistent/path.txt") + assert result is None + + def test_upload_file_success(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _prime_session(engine, "s1") + + test_file = tmp_path / "test.txt" + test_file.write_text("content") + + client = engine._get_or_create_client("s1") + client.upload_files.return_value = {"message": "Uploaded"} + + result = engine.upload_file(str(test_file)) + assert result == {"message": "Uploaded"} + + def test_list_uploads_no_session(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _clear_all_sessions(engine) + assert engine.list_uploads() is None + + def test_delete_upload_no_session(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _clear_all_sessions(engine) + assert engine.delete_upload("file.txt") is None + + +# --------------------------------------------------------------------------- +# list_sessions / list_archives +# --------------------------------------------------------------------------- + +class TestListing: + 
"""List sessions and archives (output-only, no return value).""" + + def test_list_sessions(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + _prime_session(engine, "s1", "First") + _prime_session(engine, "s2", "Second") + engine.list_sessions() # exercises print paths + # No assertion needed — no exception is success + + def test_list_archives_empty(self, tmp_path: Path): + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + engine.list_archives() + + def test_list_archives_with_files(self, tmp_path: Path): + archive_dir = tmp_path / "archive" + archive_dir.mkdir(parents=True, exist_ok=True) + (archive_dir / "arch1.json").write_text("{}") + (archive_dir / "arch2.json").write_text("{}") + + store = MagicMock() + store.sessions = {} + engine = _make_engine(store, tmp_path) + engine.list_archives() diff --git a/cli/tests/test_session_store.py b/cli/tests/test_session_store.py new file mode 100644 index 0000000000..398294f5ce --- /dev/null +++ b/cli/tests/test_session_store.py @@ -0,0 +1,372 @@ +"""Unit tests for session_store.py — async persistence and file operations.""" + +from __future__ import annotations + +import json +import threading +import time +from pathlib import Path +from unittest.mock import patch + +import pytest + +from session_store import SessionStore + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _write_session_json(sessions_dir: Path, session_id: str, title="Test Session"): + """Write a well-formed session JSON to disk for load-on-startup tests.""" + data = { + "session_id": session_id, + "info": { + "created_at": 1000000.0, + "last_active": 1000001.0, + "title": title, + "last_checkpoint_id": None, + }, + "metrics": {"total_tokens": 42, "tool_calls": 3, "turns": 1}, + } + path = sessions_dir / f"{session_id}.json" + 
path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(data)) + return path + + +def _join_worker(store: SessionStore, timeout=2.0): + """Signal the worker thread to stop and join it.""" + store.shutdown() + store._write_thread.join(timeout=timeout) + + +# --------------------------------------------------------------------------- +# __init__ and disk loading +# --------------------------------------------------------------------------- + +class TestInitAndDiskLoading: + """Startup behavior: directory creation, session loading from disk.""" + + def test_creates_directories(self, tmp_path: Path): + sessions_dir = tmp_path / "sessions" + archive_dir = tmp_path / "archive" + store = SessionStore(sessions_dir, archive_dir) + try: + assert sessions_dir.exists() + assert archive_dir.exists() + finally: + _join_worker(store) + + def test_loads_valid_sessions_from_disk(self, tmp_path: Path): + sessions_dir = tmp_path / "sessions" + archive_dir = tmp_path / "archive" + _write_session_json(sessions_dir, "abc123", "Hello") + + store = SessionStore(sessions_dir, archive_dir) + try: + assert "abc123" in store.sessions + assert store.sessions["abc123"]["title"] == "Hello" + assert store.session_metrics["abc123"]["total_tokens"] == 42 + finally: + _join_worker(store) + + def test_skips_corrupted_json_files(self, tmp_path: Path): + sessions_dir = tmp_path / "sessions" + archive_dir = tmp_path / "archive" + bad_file = sessions_dir / "bad.json" + bad_file.parent.mkdir(parents=True, exist_ok=True) + bad_file.write_text("not json at all {{{") + + store = SessionStore(sessions_dir, archive_dir) + try: + assert "bad" not in store.sessions + finally: + _join_worker(store) + + def test_empty_sessions_dir_starts_clean(self, tmp_path: Path): + sessions_dir = tmp_path / "sessions" + archive_dir = tmp_path / "archive" + store = SessionStore(sessions_dir, archive_dir) + try: + assert store.sessions == {} + finally: + _join_worker(store) + + +# 
# ---------------------------------------------------------------------------
# Shared helper for the test classes below
# ---------------------------------------------------------------------------

def _register_session(
    store,
    sid,
    title,
    *,
    created_at=1.0,
    last_active=2.0,
    total_tokens=0,
    tool_calls=0,
    turns=0,
):
    """Insert a session record and its metrics straight into ``store``'s dicts.

    Writes ``store.sessions[sid]`` and ``store.session_metrics[sid]`` with the
    same key layout every test in this module uses; nothing is queued or
    written to disk by this helper.
    """
    store.sessions[sid] = {
        "created_at": created_at,
        "last_active": last_active,
        "title": title,
        "last_checkpoint_id": None,
    }
    store.session_metrics[sid] = {
        "total_tokens": total_tokens,
        "tool_calls": tool_calls,
        "turns": turns,
    }


# ---------------------------------------------------------------------------
# save_async — enqueue writes
# ---------------------------------------------------------------------------

class TestSaveAsync:
    """Asynchronous save behaviour."""

    def test_queues_write_for_known_session(self, tmp_path: Path):
        sessions_dir = tmp_path / "sessions"
        archive_dir = tmp_path / "archive"
        store = SessionStore(sessions_dir, archive_dir)
        try:
            _register_session(store, "s1", "T")

            store.save_async("s1")

            # The pending-writes dict holds the coalesced data.
            # NOTE(review): this assumes the worker has not yet consumed the
            # entry when the lock is taken — confirm against SessionStore,
            # otherwise this assertion can race with the background thread.
            with store._lock:
                assert "s1" in store._pending_writes

            # Let the worker flush
            _join_worker(store)
            assert (sessions_dir / "s1.json").exists()
        finally:
            _join_worker(store)

    def test_noop_for_unknown_session(self, tmp_path: Path):
        sessions_dir = tmp_path / "sessions"
        archive_dir = tmp_path / "archive"
        store = SessionStore(sessions_dir, archive_dir)
        try:
            store.save_async("nonexistent")
            with store._lock:
                assert "nonexistent" not in store._pending_writes
        finally:
            _join_worker(store)

    def test_coalesces_multiple_rapid_calls(self, tmp_path: Path):
        """Repeated save_async calls for one session persist the latest state."""
        sessions_dir = tmp_path / "sessions"
        archive_dir = tmp_path / "archive"
        store = SessionStore(sessions_dir, archive_dir)
        try:
            _register_session(store, "s2", "Before")

            store.save_async("s2")
            # Mutate after the first enqueue, then enqueue again.
            store.sessions["s2"]["title"] = "After"
            store.save_async("s2")

            _join_worker(store)

            # Whatever was queued first, the file on disk must carry "After".
            data = json.loads((sessions_dir / "s2.json").read_text())
            assert data["info"]["title"] == "After"
        finally:
            _join_worker(store)


# ---------------------------------------------------------------------------
# _write_worker — background thread behavior
# ---------------------------------------------------------------------------

class TestWriteWorker:
    """Background thread lifecycle and error handling."""

    def test_exits_on_none_sentinel(self, tmp_path: Path):
        sessions_dir = tmp_path / "sessions"
        archive_dir = tmp_path / "archive"
        store = SessionStore(sessions_dir, archive_dir)
        store.shutdown()
        store._write_thread.join(timeout=3)
        assert not store._write_thread.is_alive()

    def test_drains_queue_on_shutdown(self, tmp_path: Path):
        sessions_dir = tmp_path / "sessions"
        archive_dir = tmp_path / "archive"
        store = SessionStore(sessions_dir, archive_dir)
        try:
            _register_session(store, "s3", "Q")

            for _ in range(5):
                store.save_async("s3")

            _join_worker(store)

            # All queued items were consumed and the file was written
            assert (sessions_dir / "s3.json").exists()
            # Queue should be empty after successful shutdown
            assert store._write_queue.empty()
        finally:
            # Same cleanup pattern as the other tests: stop the worker even
            # when setup or an assertion above fails.
            _join_worker(store)


# ---------------------------------------------------------------------------
# delete_session_files
# ---------------------------------------------------------------------------

class TestDeleteSessionFiles:
    """Session file deletion and in-memory cleanup."""

    def test_removes_file_and_memory_state(self, tmp_path: Path):
        sessions_dir = tmp_path / "sessions"
        archive_dir = tmp_path / "archive"
        store = SessionStore(sessions_dir, archive_dir)
        try:
            # Prime session in memory and on disk. The file is written
            # directly so the background worker plays no part in this test;
            # the pending-write path is covered by test_clears_pending_write.
            _register_session(store, "d1", "D")
            (sessions_dir / "d1.json").write_text("{}")

            store.delete_session_files("d1")

            assert "d1" not in store.sessions
            assert "d1" not in store.session_metrics
            assert not (sessions_dir / "d1.json").exists()
        finally:
            _join_worker(store)

    def test_clears_pending_write(self, tmp_path: Path):
        sessions_dir = tmp_path / "sessions"
        archive_dir = tmp_path / "archive"
        store = SessionStore(sessions_dir, archive_dir)
        try:
            _register_session(store, "d2", "D2")
            store.save_async("d2")

            store.delete_session_files("d2")

            with store._lock:
                assert "d2" not in store._pending_writes
        finally:
            _join_worker(store)


# ---------------------------------------------------------------------------
# archive_session_files
# ---------------------------------------------------------------------------

class TestArchiveSessionFiles:
    """Moving sessions to the archive directory."""

    def test_moves_existing_file_to_archive(self, tmp_path: Path):
        sessions_dir = tmp_path / "sessions"
        archive_dir = tmp_path / "archive"
        store = SessionStore(sessions_dir, archive_dir)
        try:
            _register_session(store, "a1", "A", total_tokens=1, turns=1)
            # Write to disk first so the file exists
            store.save_async("a1")
            # Drain so it's on disk
            _join_worker(store)
            # Recreate store with a fresh worker (the old one is now dead)
            store = SessionStore(sessions_dir, archive_dir)
            # Re-register explicitly so the test does not depend on what the
            # constructor loaded from disk.
            _register_session(store, "a1", "A", total_tokens=1, turns=1)

            assert (sessions_dir / "a1.json").exists()

            store.archive_session_files("a1")

            assert "a1" not in store.sessions
            assert "a1" not in store.session_metrics
            assert not (sessions_dir / "a1.json").exists()
            assert (archive_dir / "a1.json").exists()
        finally:
            _join_worker(store)

    def test_handles_unflushed_pending_write(self, tmp_path: Path):
        """When archiving a session that was saved but not yet flushed to disk."""
        sessions_dir = tmp_path / "sessions"
        archive_dir = tmp_path / "archive"
        store = SessionStore(sessions_dir, archive_dir)
        try:
            _register_session(
                store,
                "a2",
                "Unflushed",
                last_active=3.0,
                total_tokens=5,
                tool_calls=1,
                turns=1,
            )
            store.save_async("a2")
            # Do NOT drain — the intent is that the file isn't on disk yet

            store.archive_session_files("a2")

            assert "a2" not in store.sessions
            assert not (sessions_dir / "a2.json").exists()
            # The archive copy must exist even though the worker was never
            # drained (presumably written from the pending data — verify
            # against SessionStore.archive_session_files).
            assert (archive_dir / "a2.json").exists()

            data = json.loads((archive_dir / "a2.json").read_text())
            assert data["info"]["title"] == "Unflushed"
        finally:
            _join_worker(store)


# ---------------------------------------------------------------------------
# Thread safety
# ---------------------------------------------------------------------------

class TestThreadSafety:
    """Concurrent access to SessionStore from multiple threads."""

    def test_concurrent_saves_from_multiple_threads(self, tmp_path: Path):
        sessions_dir = tmp_path / "sessions"
        archive_dir = tmp_path / "archive"
        store = SessionStore(sessions_dir, archive_dir)
        errors: list[Exception] = []

        def writer(sid: str):
            # Exceptions are collected instead of raised so a failure inside a
            # thread is visible to the assertion in the main thread.
            try:
                _register_session(
                    store,
                    sid,
                    f"Thread-{sid}",
                    created_at=time.time(),
                    last_active=time.time(),
                )
                for _ in range(10):
                    store.save_async(sid)
            except Exception as e:
                errors.append(e)

        try:
            threads = [
                threading.Thread(target=writer, args=(f"t{i}",)) for i in range(5)
            ]
            for t in threads:
                t.start()
            for t in threads:
                t.join()

            _join_worker(store)

            # Show the captured exceptions on failure rather than just a count.
            assert not errors, f"writer threads raised: {errors!r}"
            for i in range(5):
                assert (sessions_dir / f"t{i}.json").exists()
        finally:
            _join_worker(store)