diff --git a/backend/app/gateway/app.py b/backend/app/gateway/app.py index 8baecb3631..9572e28841 100644 --- a/backend/app/gateway/app.py +++ b/backend/app/gateway/app.py @@ -21,6 +21,7 @@ memory, models, runs, + shares, skills, suggestions, thread_runs, @@ -351,6 +352,9 @@ def create_app() -> FastAPI: # Thread cleanup API is mounted at /api/threads/{thread_id} app.include_router(threads.router) + # Public conversation shares are mounted at /api/shares + app.include_router(shares.router) + # Agents API is mounted at /api/agents app.include_router(agents.router) diff --git a/backend/app/gateway/auth_middleware.py b/backend/app/gateway/auth_middleware.py index 6b64522643..03f0c512ab 100644 --- a/backend/app/gateway/auth_middleware.py +++ b/backend/app/gateway/auth_middleware.py @@ -49,6 +49,15 @@ def _is_public(path: str) -> bool: return any(path.startswith(prefix) for prefix in _PUBLIC_PATH_PREFIXES) +def _is_public_request(method: str, path: str) -> bool: + """Return True for routes that are intentionally anonymous.""" + if _is_public(path): + return True + stripped = path.rstrip("/") + parts = stripped.split("/") + return method == "GET" and len(parts) == 4 and parts[:3] == ["", "api", "shares"] and bool(parts[3]) and parts[3] != "threads" + + class AuthMiddleware(BaseHTTPMiddleware): """Strict auth gate: reject requests without a valid session. @@ -73,7 +82,7 @@ def __init__(self, app: ASGIApp) -> None: super().__init__(app) async def dispatch(self, request: Request, call_next: Callable) -> Response: - if _is_public(request.url.path): + if _is_public_request(request.method, request.url.path): return await call_next(request) internal_user = None diff --git a/backend/app/gateway/routers/__init__.py b/backend/app/gateway/routers/__init__.py index c5f67a396b..2dbb38047f 100644 --- a/backend/app/gateway/routers/__init__.py +++ b/backend/app/gateway/routers/__init__.py @@ -1,3 +1,3 @@ -from . import artifacts, assistants_compat, mcp, models, skills, suggestions, thread_runs, threads, uploads +from . import artifacts, assistants_compat, mcp, models, shares, skills, suggestions, thread_runs, threads, uploads -__all__ = ["artifacts", "assistants_compat", "mcp", "models", "skills", "suggestions", "threads", "thread_runs", "uploads"] +__all__ = ["artifacts", "assistants_compat", "mcp", "models", "shares", "skills", "suggestions", "threads", "thread_runs", "uploads"] diff --git a/backend/app/gateway/routers/shares.py b/backend/app/gateway/routers/shares.py new file mode 100644 index 0000000000..e4df5d62cf --- /dev/null +++ b/backend/app/gateway/routers/shares.py @@ -0,0 +1,311 @@ +"""Public conversation share endpoints.""" + +from __future__ import annotations + +import logging +import secrets +from datetime import UTC, datetime, timedelta +from typing import Any + +from fastapi import APIRouter, HTTPException, Request, Response +from pydantic import BaseModel, Field + +from app.gateway.authz import get_auth_context, require_permission +from app.gateway.deps import get_checkpointer, get_store, get_thread_store +from app.gateway.utils import sanitize_log_param +from deerflow.runtime import serialize_channel_values +from deerflow.utils.time import now_iso + +logger = logging.getLogger(__name__) +router = APIRouter(prefix="/api/shares", tags=["shares"]) + +_SHARES_NS = ("shares",) +_SHARE_ID_BYTES = 16 +_SHARE_RETENTION = timedelta(days=30) +_SHARE_TTL_MINUTES = _SHARE_RETENTION.total_seconds() / 60 +_EXPIRED_SHARE_CLEANUP_BATCH_SIZE = 100 +_EXPIRED_SHARE_CLEANUP_MAX_BATCHES = 10 +_PUBLIC_LINK_VISIBILITY = "public_link" +_REVOKED_VISIBILITY = "revoked" + + +class ShareCreateRequest(BaseModel): + """Request body for creating a public share snapshot.""" + + message_ids: list[str] = Field( + min_length=1, + description="Message IDs to include in the public share.", + ) + title: str | None = Field(default=None, max_length=256, description="Optional share title") + + +class ShareCreateResponse(BaseModel): + share_id: str + title: str | None = None + created_at: str + + +class ShareResponse(BaseModel): + share_id: str + title: str | None = None + messages: list[dict[str, Any]] = Field(default_factory=list) + created_at: str + + +def _parse_iso_datetime(value: Any) -> datetime | None: + if not isinstance(value, str): + return None + try: + parsed = datetime.fromisoformat(value) + except ValueError: + return None + if parsed.tzinfo is None: + return parsed.replace(tzinfo=UTC) + return parsed.astimezone(UTC) + + +def _is_expired_share(value: dict[str, Any], *, now: datetime | None = None) -> bool: + expires_at = _parse_iso_datetime(value.get("expires_at")) + if expires_at is None: + return False + return expires_at <= (now or datetime.now(UTC)) + + +def _get_request_user_id(request: Request) -> str: + auth = get_auth_context(request) + if auth is None: + raise HTTPException(status_code=401, detail="Authentication required") + return str(auth.require_user().id) + + +async def _require_explicit_thread_owner(request: Request, thread_id: str, user_id: str) -> None: + thread_store = get_thread_store(request) + record = await thread_store.get(thread_id, user_id=None) + if record is None or record.get("user_id") != user_id: + raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found") + + +def _extract_message_id(message: dict[str, Any]) -> str | None: + message_id = message.get("id") + return message_id if isinstance(message_id, str) and message_id else None + + +def _has_displayable_content(message: dict[str, Any]) -> bool: + content = message.get("content") + if isinstance(content, str): + return bool(content.strip()) + if isinstance(content, list): + return len(content) > 0 + return content is not None + + +def _is_shareable_message(message: dict[str, Any]) -> bool: + message_type = message.get("type") + if message_type == "human": + return _has_displayable_content(message) + if message_type == "ai": + has_tool_metadata = bool(message.get("tool_calls") or message.get("invalid_tool_calls")) + return _has_displayable_content(message) and not has_tool_metadata + return False + + +def _to_public_message(message: dict[str, Any]) -> dict[str, Any]: + """Keep only fields needed to render a public read-only message.""" + public_message: dict[str, Any] = { + "type": message.get("type"), + "content": message.get("content"), + } + message_id = _extract_message_id(message) + if message_id is not None: + public_message["id"] = message_id + return public_message + + +async def _put_unique_share(store, value: dict[str, Any]) -> str: + await _delete_expired_shares(store) + ttl = _SHARE_TTL_MINUTES if getattr(store, "supports_ttl", False) else None + for _ in range(4): + share_id = secrets.token_urlsafe(_SHARE_ID_BYTES) + if await store.aget(_SHARES_NS, share_id) is None: + if ttl is None: + await store.aput(_SHARES_NS, share_id, value) + else: + await store.aput(_SHARES_NS, share_id, value, ttl=ttl) + return share_id + raise HTTPException(status_code=500, detail="Failed to create share") + + +async def _delete_expired_shares(store) -> None: + if getattr(store, "supports_ttl", False): + return + try: + expired_items: list[tuple[tuple[str, ...], str]] = [] + now = datetime.now(UTC) + for batch_index in range(_EXPIRED_SHARE_CLEANUP_MAX_BATCHES): + items = await store.asearch( + _SHARES_NS, + limit=_EXPIRED_SHARE_CLEANUP_BATCH_SIZE, + offset=batch_index * _EXPIRED_SHARE_CLEANUP_BATCH_SIZE, + refresh_ttl=False, + ) + for item in items: + if _is_expired_share(item.value or {}, now=now): + expired_items.append((tuple(item.namespace), item.key)) + if len(items) < _EXPIRED_SHARE_CLEANUP_BATCH_SIZE: + break + for namespace, key in expired_items: + await store.adelete(namespace, key) + except Exception: + logger.debug("Failed to cleanup expired share snapshots", exc_info=True) + + +@router.post("/threads/{thread_id}", response_model=ShareCreateResponse) +@require_permission("threads", "read", owner_check=True, require_existing=True) +async def create_thread_share(thread_id: str, body: ShareCreateRequest, request: Request) -> ShareCreateResponse: + """Create a public immutable snapshot from an owned thread.""" + store = get_store(request) + if store is None: + raise HTTPException(status_code=503, detail="Store not available") + + user_id = _get_request_user_id(request) + await _require_explicit_thread_owner(request, thread_id, user_id) + + checkpointer = get_checkpointer(request) + if checkpointer is None: + raise HTTPException(status_code=503, detail="Checkpointer not available") + + config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}} + try: + checkpoint_tuple = await checkpointer.aget_tuple(config) + except Exception: + logger.exception( + "Failed to get state for share source thread %s", + sanitize_log_param(thread_id), + ) + raise HTTPException(status_code=500, detail="Failed to create share") + + if checkpoint_tuple is None: + raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found") + + checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {} + channel_values = checkpoint.get("channel_values", {}) or {} + serialized_values = serialize_channel_values(channel_values) + all_messages = serialized_values.get("messages", []) + if not isinstance(all_messages, list) or not all_messages: + raise HTTPException(status_code=400, detail="Thread has no messages to share") + + requested_ids = [message_id for message_id in body.message_ids if message_id] + if not requested_ids: + raise HTTPException(status_code=400, detail="No message IDs selected") + + requested_id_set = set(requested_ids) + selected_messages: list[dict[str, Any]] = [] + selected_id_set: set[str] = set() + for message in all_messages: + if not isinstance(message, dict): + continue + message_id = _extract_message_id(message) + if message_id in requested_id_set: + selected_messages.append(message) + selected_id_set.add(message_id) + + missing_ids = [message_id for message_id in requested_ids if message_id not in selected_id_set] + if missing_ids: + raise HTTPException( + status_code=400, + detail=f"Message IDs not found: {', '.join(missing_ids)}", + ) + + non_shareable_ids: list[str] = [] + for message in selected_messages: + message_id = _extract_message_id(message) + if message_id is not None and not _is_shareable_message(message): + non_shareable_ids.append(message_id) + if non_shareable_ids: + raise HTTPException( + status_code=400, + detail=f"Message IDs are not shareable: {', '.join(non_shareable_ids)}", + ) + + created_at = now_iso() + expires_at = (datetime.now(UTC) + _SHARE_RETENTION).isoformat() + title = serialized_values.get("title") if body.title is None else body.title + if not isinstance(title, str): + title = None + + share_id = await _put_unique_share( + store, + { + "title": title, + "messages": [_to_public_message(message) for message in selected_messages], + "source_thread_id": thread_id, + "created_by_user_id": user_id, + "message_ids": requested_ids, + "visibility": _PUBLIC_LINK_VISIBILITY, + "granted_at": created_at, + "created_at": created_at, + "expires_at": expires_at, + }, + ) + return ShareCreateResponse(share_id=share_id, title=title, created_at=created_at) + + +@router.get("/{share_id}", response_model=ShareResponse) +async def get_share(share_id: str, request: Request) -> ShareResponse: + """Read a public share snapshot without requiring authentication.""" + store = get_store(request) + if store is None: + raise HTTPException(status_code=503, detail="Store not available") + + item = await store.aget(_SHARES_NS, share_id) + if item is None: + raise HTTPException(status_code=404, detail="Share not found") + + value = item.value or {} + if value.get("visibility") == _REVOKED_VISIBILITY or value.get("revoked_at"): + raise HTTPException(status_code=404, detail="Share not found") + if _is_expired_share(value): + await store.adelete(_SHARES_NS, share_id) + raise HTTPException(status_code=404, detail="Share not found") + + messages = value.get("messages", []) + if not isinstance(messages, list): + messages = [] + public_messages: list[dict[str, Any]] = [] + for message in messages: + if isinstance(message, dict) and _is_shareable_message(message): + public_messages.append(_to_public_message(message)) + title = value.get("title") + return ShareResponse( + share_id=share_id, + title=title if isinstance(title, str) else None, + messages=public_messages, + created_at=value.get("created_at", ""), + ) + + +@router.delete("/{share_id}", status_code=204) +async def revoke_share(share_id: str, request: Request) -> Response: + """Revoke a public share link created by the current user.""" + store = get_store(request) + if store is None: + raise HTTPException(status_code=503, detail="Store not available") + + user_id = _get_request_user_id(request) + item = await store.aget(_SHARES_NS, share_id) + if item is None: + raise HTTPException(status_code=404, detail="Share not found") + + value = item.value or {} + if value.get("created_by_user_id") != user_id or _is_expired_share(value): + raise HTTPException(status_code=404, detail="Share not found") + + value = dict(value) + value["visibility"] = _REVOKED_VISIBILITY + value["revoked_at"] = now_iso() + ttl = _SHARE_TTL_MINUTES if getattr(store, "supports_ttl", False) else None + if ttl is None: + await store.aput(_SHARES_NS, share_id, value) + else: + await store.aput(_SHARES_NS, share_id, value, ttl=ttl) + return Response(status_code=204) diff --git a/backend/tests/test_auth_middleware.py b/backend/tests/test_auth_middleware.py index 726786ac9c..fea4343865 100644 --- a/backend/tests/test_auth_middleware.py +++ b/backend/tests/test_auth_middleware.py @@ -3,7 +3,7 @@ import pytest from starlette.testclient import TestClient -from app.gateway.auth_middleware import AuthMiddleware, _is_public +from app.gateway.auth_middleware import AuthMiddleware, _is_public, _is_public_request # ── _is_public unit tests ───────────────────────────────────────────────── @@ -27,6 +27,15 @@ def test_public_paths(path: str): assert _is_public(path) is True +def test_public_share_read_request(): + assert _is_public_request("GET", "/api/shares/share-1") is True + assert _is_public_request("DELETE", "/api/shares/share-1") is False + assert _is_public_request("GET", "/api/shares/threads") is False + assert _is_public_request("POST", "/api/shares/threads/thread-1") is False + assert _is_public_request("GET", "/api/shares/threads/thread-1") is False + assert _is_public_request("GET", "/api/shares-anything") is False + + @pytest.mark.parametrize( "path", [ @@ -129,6 +138,26 @@ async def stream(): async def future(): return {"ok": True} + @app.get("/api/shares/share-1") + async def share_get(): + return {"ok": True} + + @app.delete("/api/shares/share-1") + async def share_delete(): + return {"ok": True} + + @app.get("/api/shares/threads") + async def share_threads_reserved(): + return {"ok": True} + + @app.post("/api/shares/threads/abc") + async def share_create(): + return {"ok": True} + + @app.get("/api/shares-anything") + async def shares_prefix_lookalike(): + return {"ok": True} + return app @@ -148,6 +177,31 @@ def test_public_auth_path_no_cookie(client): assert res.status_code == 200 +def test_public_share_path_no_cookie(client): + res = client.get("/api/shares/share-1") + assert res.status_code == 200 + + +def test_share_create_no_cookie_returns_401(client): + res = client.post("/api/shares/threads/abc") + assert res.status_code == 401 + + +def test_share_revoke_no_cookie_returns_401(client): + res = client.delete("/api/shares/share-1") + assert res.status_code == 401 + + +def test_share_threads_reserved_no_cookie_returns_401(client): + res = client.get("/api/shares/threads") + assert res.status_code == 401 + + +def test_share_prefix_lookalike_no_cookie_returns_401(client): + res = client.get("/api/shares-anything") + assert res.status_code == 401 + + def test_protected_auth_path_no_cookie(client): """/auth/me requires cookie even though it's under /api/v1/auth/.""" res = client.get("/api/v1/auth/me") diff --git a/backend/tests/test_shares_router.py b/backend/tests/test_shares_router.py new file mode 100644 index 0000000000..95f74bac2d --- /dev/null +++ b/backend/tests/test_shares_router.py @@ -0,0 +1,392 @@ +import asyncio +from types import SimpleNamespace +from uuid import UUID + +from _router_auth_helpers import make_authed_test_app +from fastapi.testclient import TestClient +from langchain_core.messages import AIMessage, HumanMessage +from langgraph.store.memory import InMemoryStore + +from app.gateway.auth.models import User +from app.gateway.routers import shares +from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore + +_USER_ID = UUID("11111111-1111-1111-1111-111111111111") +_OTHER_USER_ID = UUID("22222222-2222-2222-2222-222222222222") + + +def _make_user(user_id: UUID = _USER_ID) -> User: + return User( + email="share-test@example.com", + password_hash="x", + system_role="user", + id=user_id, + ) + + +class _ShareTestStore(InMemoryStore): + def __init__(self, *, supports_ttl: bool = True) -> None: + super().__init__() + self.supports_ttl = supports_ttl + self.put_ttls: list[float | None] = [] + + async def aput(self, *args, **kwargs): # type: ignore[no-untyped-def] + if args and args[0] == shares._SHARES_NS: + self.put_ttls.append(kwargs.get("ttl")) + return await super().aput(*args, **kwargs) + + +class _FakeCheckpointer: + def __init__(self) -> None: + self.checkpoints: dict[str, dict] = {} + + async def aget_tuple(self, config: dict): + thread_id = config["configurable"]["thread_id"] + checkpoint = self.checkpoints.get(thread_id) + if checkpoint is None: + return None + return SimpleNamespace(checkpoint=checkpoint) + + +def _build_share_app( + *, + owner_check_passes: bool = True, + store_supports_ttl: bool = True, + user_id: UUID = _USER_ID, +) -> tuple[TestClient, _ShareTestStore, _FakeCheckpointer]: + store = _ShareTestStore(supports_ttl=store_supports_ttl) + app = make_authed_test_app(user_factory=lambda: _make_user(user_id), owner_check_passes=owner_check_passes) + checkpointer = _FakeCheckpointer() + app.state.store = store + app.state.checkpointer = checkpointer + if owner_check_passes: + app.state.thread_store = MemoryThreadMetaStore(store) + app.include_router(shares.router) + return TestClient(app), store, checkpointer + + +def _seed_thread( + store: InMemoryStore, + checkpointer: _FakeCheckpointer, + thread_id: str, + *, + owner_user_id: UUID | None = _USER_ID, +) -> None: + async def _seed() -> None: + await MemoryThreadMetaStore(store).create( + thread_id, + metadata={}, + user_id=str(owner_user_id) if owner_user_id is not None else None, + ) + checkpointer.checkpoints[thread_id] = { + "channel_values": { + "title": "Share source", + "messages": [ + HumanMessage(content="Question", id="human-1"), + AIMessage( + content="Answer", + id="ai-1", + additional_kwargs={"private": "not public"}, + response_metadata={"model": "hidden"}, + ), + AIMessage( + content="", + id="tool-call-1", + tool_calls=[{"name": "search", "args": {}, "id": "call-1"}], + ), + HumanMessage(content="Follow-up", id="human-2"), + ], + }, + } + + asyncio.run(_seed()) + + +def test_create_share_snapshots_selected_messages_and_public_read() -> None: + client, store, checkpointer = _build_share_app() + _seed_thread(store, checkpointer, "thread-share") + + response = client.post( + "/api/shares/threads/thread-share", + json={"message_ids": ["human-1", "ai-1"]}, + ) + + assert response.status_code == 200, response.text + share_id = response.json()["share_id"] + + async def _assert_internal_grant_metadata() -> None: + item = await store.aget(shares._SHARES_NS, share_id) + assert item is not None + assert item.value["source_thread_id"] == "thread-share" + assert item.value["created_by_user_id"] == str(_USER_ID) + assert item.value["message_ids"] == ["human-1", "ai-1"] + assert item.value["visibility"] == "public_link" + assert item.value["granted_at"] == item.value["created_at"] + + asyncio.run(_assert_internal_grant_metadata()) + + public_response = client.get(f"/api/shares/{share_id}") + assert public_response.status_code == 200, public_response.text + body = public_response.json() + assert body["title"] == "Share source" + assert [message["id"] for message in body["messages"]] == ["human-1", "ai-1"] + assert [message["content"] for message in body["messages"]] == ["Question", "Answer"] + assert body["messages"][1] == {"type": "ai", "content": "Answer", "id": "ai-1"} + assert "response_metadata" not in body["messages"][1] + assert "additional_kwargs" not in body["messages"][1] + assert "source_thread_id" not in body + assert "created_by_user_id" not in body + assert "message_ids" not in body + assert "visibility" not in body + assert "granted_at" not in body + assert store.put_ttls == [shares._SHARE_TTL_MINUTES] + + +def test_create_share_keeps_intentional_empty_title() -> None: + client, store, checkpointer = _build_share_app() + _seed_thread(store, checkpointer, "thread-share") + + response = client.post( + "/api/shares/threads/thread-share", + json={"message_ids": ["human-1", "ai-1"], "title": ""}, + ) + + assert response.status_code == 200, response.text + assert response.json()["title"] == "" + + +def test_create_share_rejects_unknown_message_id() -> None: + client, store, checkpointer = _build_share_app() + _seed_thread(store, checkpointer, "thread-share") + + response = client.post( + "/api/shares/threads/thread-share", + json={"message_ids": ["missing-message"]}, + ) + + assert response.status_code == 400 + assert "missing-message" in response.json()["detail"] + + +def test_create_share_requires_selected_message_ids() -> None: + client, store, checkpointer = _build_share_app() + _seed_thread(store, checkpointer, "thread-share") + + response = client.post("/api/shares/threads/thread-share", json={}) + + assert response.status_code == 422 + + +def test_create_share_rejects_non_shareable_message_id() -> None: + client, store, checkpointer = _build_share_app() + _seed_thread(store, checkpointer, "thread-share") + + response = client.post( + "/api/shares/threads/thread-share", + json={"message_ids": ["tool-call-1"]}, + ) + + assert response.status_code == 400 + assert "not shareable" in response.json()["detail"] + + +def test_create_share_requires_thread_access() -> None: + client, store, checkpointer = _build_share_app(owner_check_passes=False) + _seed_thread(store, checkpointer, "thread-share") + + response = client.post( + "/api/shares/threads/thread-share", + json={"message_ids": ["human-1", "ai-1"]}, + ) + + assert response.status_code == 404 + + +def test_create_share_requires_explicit_thread_owner() -> None: + client, store, checkpointer = _build_share_app() + _seed_thread(store, checkpointer, "legacy-thread", owner_user_id=None) + + response = client.post( + "/api/shares/threads/legacy-thread", + json={"message_ids": ["human-1", "ai-1"]}, + ) + + assert response.status_code == 404 + + +def test_create_share_rejects_thread_owned_by_another_user() -> None: + client, store, checkpointer = _build_share_app() + _seed_thread(store, checkpointer, "other-thread", owner_user_id=_OTHER_USER_ID) + + response = client.post( + "/api/shares/threads/other-thread", + json={"message_ids": ["human-1", "ai-1"]}, + ) + + assert response.status_code == 404 + + +def test_create_share_returns_503_without_checkpointer() -> None: + client, store, _checkpointer = _build_share_app() + _seed_thread(store, _checkpointer, "thread-share") + client.app.state.checkpointer = None + + response = client.post( + "/api/shares/threads/thread-share", + json={"message_ids": ["human-1", "ai-1"]}, + ) + + assert response.status_code == 503 + + +def test_get_share_deletes_expired_snapshot_when_ttl_is_unavailable() -> None: + client, store, checkpointer = _build_share_app(store_supports_ttl=False) + _seed_thread(store, checkpointer, "thread-share") + + response = client.post( + "/api/shares/threads/thread-share", + json={"message_ids": ["human-1", "ai-1"]}, + ) + assert response.status_code == 200, response.text + share_id = response.json()["share_id"] + + async def _expire_share() -> None: + item = await store.aget(shares._SHARES_NS, share_id) + assert item is not None + value = dict(item.value) + value["expires_at"] = "2000-01-01T00:00:00+00:00" + await store.aput(shares._SHARES_NS, share_id, value) + + asyncio.run(_expire_share()) + + public_response = client.get(f"/api/shares/{share_id}") + assert public_response.status_code == 404 + + async def _assert_deleted() -> None: + assert await store.aget(shares._SHARES_NS, share_id) is None + + asyncio.run(_assert_deleted()) + + +def test_revoke_share_hides_public_snapshot() -> None: + client, store, checkpointer = _build_share_app() + _seed_thread(store, checkpointer, "thread-share") + + response = client.post( + "/api/shares/threads/thread-share", + json={"message_ids": ["human-1", "ai-1"]}, + ) + assert response.status_code == 200, response.text + share_id = response.json()["share_id"] + + revoke_response = client.delete(f"/api/shares/{share_id}") + + assert revoke_response.status_code == 204, revoke_response.text + + async def _assert_revoked() -> None: + item = await store.aget(shares._SHARES_NS, share_id) + assert item is not None + assert item.value["visibility"] == "revoked" + assert item.value["revoked_at"] + + asyncio.run(_assert_revoked()) + + public_response = client.get(f"/api/shares/{share_id}") + assert public_response.status_code == 404 + + async def _assert_revocation_record_kept() -> None: + item = await store.aget(shares._SHARES_NS, share_id) + assert item is not None + assert item.value["visibility"] == "revoked" + + asyncio.run(_assert_revocation_record_kept()) + + +def test_revoke_share_requires_creator() -> None: + client, store, checkpointer = _build_share_app() + _seed_thread(store, checkpointer, "thread-share") + + response = client.post( + "/api/shares/threads/thread-share", + json={"message_ids": ["human-1", "ai-1"]}, + ) + assert response.status_code == 200, response.text + share_id = response.json()["share_id"] + + other_client, _store, _checkpointer = _build_share_app(user_id=_OTHER_USER_ID) + other_client.app.state.store = store + other_client.app.state.checkpointer = checkpointer + other_client.app.state.thread_store = MemoryThreadMetaStore(store) + + revoke_response = other_client.delete(f"/api/shares/{share_id}") + + assert revoke_response.status_code == 404 + assert client.get(f"/api/shares/{share_id}").status_code == 200 + + +def test_create_share_cleans_expired_snapshots_beyond_first_batch_without_ttl() -> None: + client, store, checkpointer = _build_share_app(store_supports_ttl=False) + _seed_thread(store, checkpointer, "thread-share") + + async def _seed_expired_shares() -> None: + for index in range(shares._EXPIRED_SHARE_CLEANUP_BATCH_SIZE + 5): + await store.aput( + shares._SHARES_NS, + f"expired-share-{index:03d}", + { + "created_at": "2000-01-01T00:00:00+00:00", + "expires_at": "2000-01-02T00:00:00+00:00", + "messages": [], + }, + ) + + asyncio.run(_seed_expired_shares()) + + response = client.post( + "/api/shares/threads/thread-share", + json={"message_ids": ["human-1", "ai-1"]}, + ) + + assert response.status_code == 200, response.text + + async def _assert_expired_deleted() -> None: + items = await store.asearch(shares._SHARES_NS, limit=200, refresh_ttl=False) + assert not [item.key for item in items if item.key.startswith("expired-share-")] + + asyncio.run(_assert_expired_deleted()) + + +def test_get_share_normalizes_stored_messages() -> None: + client, store, _checkpointer = _build_share_app() + + async def _seed_share() -> None: + await store.aput( + shares._SHARES_NS, + "share-with-metadata", + { + "title": "Legacy share", + "created_at": "2026-05-28T00:00:00+00:00", + "messages": [ + { + "id": "ai-1", + "type": "ai", + "content": "Answer", + "response_metadata": {"model": "hidden"}, + "additional_kwargs": {"private": "not public"}, + }, + { + "id": "tool-call-1", + "type": "ai", + "content": "", + "tool_calls": [{"name": "search", "args": {}, "id": "call-1"}], + }, + ], + }, + ) + + asyncio.run(_seed_share()) + + response = client.get("/api/shares/share-with-metadata") + + assert response.status_code == 200, response.text + assert response.json()["messages"] == [{"type": "ai", "content": "Answer", "id": "ai-1"}] diff --git a/frontend/src/app/share/[share_id]/page.tsx b/frontend/src/app/share/[share_id]/page.tsx new file mode 100644 index 0000000000..b21eda76aa --- /dev/null +++ b/frontend/src/app/share/[share_id]/page.tsx @@ -0,0 +1,113 @@ +"use client"; + +import type { BaseStream } from "@langchain/langgraph-sdk/react"; +import { useParams } from "next/navigation"; +import { useEffect, useMemo, useState } from "react"; + +import { SidebarProvider } from "@/components/ui/sidebar"; +import { ArtifactsProvider } from "@/components/workspace/artifacts"; +import { MessageList } from "@/components/workspace/messages"; +import { getBackendBaseURL } from "@/core/config"; +import { SubtasksProvider } from "@/core/tasks/context"; +import type { AgentThreadState, ThreadShareResponse } from "@/core/threads"; +import { readErrorDetail } from "@/core/threads/api"; + +async function readShareLoadError(response: Response) { + const message = await readErrorDetail( + response, + response.status === 404 ? "Share not found" : "Failed to load share", + ); + return message.includes(`(${response.status})`) + ? message + : `${message} (${response.status})`; +} + +export default function SharePage() { + const { share_id: shareId } = useParams<{ share_id: string }>(); + const [share, setShare] = useState(null); + const [error, setError] = useState(null); + + useEffect(() => { + let cancelled = false; + + async function loadShare() { + try { + const response = await fetch( + `${getBackendBaseURL()}/api/shares/${encodeURIComponent(shareId)}`, + ); + if (!response.ok) { + throw new Error(await readShareLoadError(response)); + } + const data = (await response.json()) as ThreadShareResponse; + if (!cancelled) { + setShare(data); + } + } catch (err) { + if (!cancelled) { + setError(err instanceof Error ? err.message : "Failed to load share"); + } + } + } + + void loadShare(); + return () => { + cancelled = true; + }; + }, [shareId]); + + const thread = useMemo( + () => + ({ + messages: share?.messages ?? [], + values: { + title: share?.title ?? "", + messages: share?.messages ?? [], + artifacts: [], + }, + isLoading: false, + isThreadLoading: share === null && error === null, + getMessagesMetadata: () => [], + }) as unknown as BaseStream, + [error, share], + ); + + return ( + + + +
+
+
+
+
+ DeerFlow +
+ {share?.title && ( +

+ {share.title} +

+ )} +
+
+
+ {error ? ( +
+ {error} +
+ ) : ( +
+ +
+ )} +
+
+
+
+ ); +} diff --git a/frontend/src/components/workspace/messages/message-list.tsx b/frontend/src/components/workspace/messages/message-list.tsx index ca8672a3a6..b5e0852bf1 100644 --- a/frontend/src/components/workspace/messages/message-list.tsx +++ b/frontend/src/components/workspace/messages/message-list.tsx @@ -1,7 +1,8 @@ import type { Message } from "@langchain/langgraph-sdk"; import type { BaseStream } from "@langchain/langgraph-sdk/react"; -import { ChevronUpIcon, Loader2Icon } from "lucide-react"; +import { ChevronUpIcon, Loader2Icon, Share2Icon } from "lucide-react"; import { useCallback, useEffect, useMemo, useRef } from "react"; +import { toast } from "sonner"; import { Conversation, @@ -31,11 +32,13 @@ import type { Subtask } from "@/core/tasks"; import { useUpdateSubtask } from "@/core/tasks/context"; import { parseSubtaskResult } from "@/core/tasks/subtask-result"; import type { AgentThreadState } from "@/core/threads"; +import { createThreadShare } from "@/core/threads/api"; import { cn } from "@/lib/utils"; import { ArtifactFileList } from "../artifacts/artifact-file-list"; import { CopyButton } from "../copy-button"; import { StreamingIndicator } from "../streaming-indicator"; +import { Tooltip } from "../tooltip"; import { MarkdownContent } from "./markdown-content"; import { MessageGroup } from "./message-group"; @@ -165,6 +168,7 @@ export function MessageList({ hasMoreHistory, loadMoreHistory, isHistoryLoading, + enableSharing = true, }: { className?: string; threadId: string; @@ -174,14 +178,30 @@ export function MessageList({ hasMoreHistory?: boolean; loadMoreHistory?: () => void; isHistoryLoading?: boolean; + enableSharing?: boolean; }) { const { t } = useI18n(); const rehypePlugins = useRehypeSplitWordsIntoSpans(thread.isLoading); const updateSubtask = useUpdateSubtask(); const messages = thread.messages; - const groupedMessages = getMessageGroups(messages); + const groupedMessages = useMemo(() => getMessageGroups(messages), [messages]); const turnUsageMessagesByGroupIndex = getAssistantTurnUsageMessages(groupedMessages); + const previousHumanMessagesByGroupIndex = useMemo(() => { + const previousByIndex: Message[][] = []; + let previousHumanMessages: Message[] | null = null; + + groupedMessages.forEach((group, index) => { + previousByIndex[index] = previousHumanMessages ?? []; + if (group.type === "human") { + previousHumanMessages = group.messages; + } else if (group.type === "assistant") { + previousHumanMessages = null; + } + }); + + return previousByIndex; + }, [groupedMessages]); const tokenDebugSteps = useMemo( () => buildTokenDebugSteps(messages, t), [messages, t], @@ -196,21 +216,84 @@ export function MessageList({ [messages, thread.getMessagesMetadata, thread.isLoading], ); - const renderAssistantCopyButton = useCallback( - (messages: Message[], isStreaming: boolean) => { + const renderAssistantActions = useCallback( + ( + messages: Message[], + isStreaming: boolean, + previousMessages: Message[], + ) => { const clipboardData = getAssistantTurnCopyData(messages, { isStreaming }); if (!clipboardData) { return null; } + const shareMessageIds = [...previousMessages, ...messages] + .map((message) => message.id) + .filter((id): id is string => typeof id === "string" && id.length > 0); + return ( -
+
+ {enableSharing && ( + + + + )}
); }, - [], + [ + t.clipboard.failedToCopyToClipboard, + t.clipboard.linkCopied, + t.common.share, + t.conversation.noMessages, + enableSharing, + thread.values.title, + threadId, + ], ); const renderTokenUsage = useCallback( @@ -275,6 +358,8 @@ export function MessageList({ /> {groupedMessages.map((group, groupIndex) => { const turnUsageMessages = turnUsageMessagesByGroupIndex[groupIndex]; + const previousMessages = + previousHumanMessagesByGroupIndex[groupIndex] ?? []; if (group.type === "human" || group.type === "assistant") { return ( @@ -301,12 +386,13 @@ export function MessageList({ turnUsageMessages, })} {group.type === "assistant" && - renderAssistantCopyButton( + renderAssistantActions( group.messages, isAssistantMessageGroupStreaming( group.messages, streamingMessages, ), + previousMessages, )}
); diff --git a/frontend/src/core/threads/api.ts b/frontend/src/core/threads/api.ts index 1d1feb40f7..04161e68dd 100644 --- a/frontend/src/core/threads/api.ts +++ b/frontend/src/core/threads/api.ts @@ -1,7 +1,54 @@ import { fetch as fetchWithAuth } from "@/core/api/fetcher"; import { getBackendBaseURL } from "@/core/config"; -import type { ThreadTokenUsageResponse } from "./types"; +import type { + ThreadShareCreateResponse, + ThreadTokenUsageResponse, +} from "./types"; + +function formatErrorDetail(detail: unknown): string | null { + if (typeof detail === "string" && detail) { + return detail; + } + if (Array.isArray(detail)) { + const messages = detail + .map((item) => { + if (typeof item === "string") { + return item; + } + if (item && typeof item === "object" && "msg" in item) { + const message = (item as { msg?: unknown }).msg; + return typeof message === "string" && message ? message : null; + } + return null; + }) + .filter((message): message is string => Boolean(message)); + if (messages.length > 0) { + return messages.join("; "); + } + } + if (detail && typeof detail === "object") { + try { + return JSON.stringify(detail); + } catch { + return null; + } + } + return null; +} + +export async function readErrorDetail(response: Response, fallback: string) { + try { + const body = (await response.json()) as { detail?: unknown }; + const detail = formatErrorDetail(body.detail); + if (detail) { + return detail; + } + } catch { + // Ignore malformed error bodies and keep the stable fallback message. + } + return `${fallback} (${response.status})`; +} export async function fetchThreadTokenUsage( threadId: string, @@ -22,3 +69,31 @@ export async function fetchThreadTokenUsage( return (await response.json()) as ThreadTokenUsageResponse; } + +export async function createThreadShare({ + threadId, + messageIds, + title, +}: { + threadId: string; + messageIds: string[]; + title?: string; +}): Promise { + const response = await fetchWithAuth( + `${getBackendBaseURL()}/api/shares/threads/${encodeURIComponent(threadId)}`, + { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + message_ids: messageIds, + title, + }), + }, + ); + + if (!response.ok) { + throw new Error(await readErrorDetail(response, "Failed to create share.")); + } + + return (await response.json()) as ThreadShareCreateResponse; +} diff --git a/frontend/src/core/threads/types.ts b/frontend/src/core/threads/types.ts index dafb073494..1a783c0e1a 100644 --- a/frontend/src/core/threads/types.ts +++ b/frontend/src/core/threads/types.ts @@ -45,3 +45,13 @@ export interface ThreadTokenUsageResponse { middleware: number; }; } + +export interface ThreadShareCreateResponse { + share_id: string; + title?: string | null; + created_at: string; +} + +export interface ThreadShareResponse extends ThreadShareCreateResponse { + messages: Message[]; +} diff --git a/frontend/tests/e2e/share.spec.ts b/frontend/tests/e2e/share.spec.ts new file mode 100644 index 0000000000..0df60a0f55 --- /dev/null +++ b/frontend/tests/e2e/share.spec.ts @@ -0,0 +1,62 @@ +import { expect, test } from "@playwright/test"; + +const SHARE_ID = "share-public-1"; + +test.describe("Public share page", () => { + test("renders a shared answer snapshot without workspace APIs", async ({ + page, + }) => { + await page.route(`**/api/shares/${SHARE_ID}`, (route) => + route.fulfill({ + status: 200, + contentType: "application/json", + body: JSON.stringify({ + share_id: SHARE_ID, + title: "Research summary", + created_at: "2026-01-01T00:00:00Z", + messages: [ + { + type: "human", + id: "human-1", + content: [{ type: "text", text: "Summarize the report" }], + }, + { + type: "ai", + id: "ai-1", + content: "The report highlights revenue growth.", + }, + ], + }), + }), + ); + + await page.goto(`/share/${SHARE_ID}`); + + await expect( + page.getByRole("heading", { name: "Research summary" }), + ).toBeVisible(); + await expect(page.getByText("Summarize the report")).toBeVisible({ + timeout: 15_000, + }); + await expect( + page.getByText("The report highlights revenue growth."), + ).toBeVisible(); + await expect(page.getByRole("button", { name: /share/i })).toHaveCount(0); + }); + + test("shows the public share load error", async ({ page }) => { + await page.route("**/api/shares/missing-share", (route) => + route.fulfill({ + status: 404, + contentType: "application/json", + body: JSON.stringify({ detail: "Share not found" }), + }), + ); + + await page.goto("/share/missing-share"); + + await expect(page.getByText("Share not found (404)")).toBeVisible({ + timeout: 15_000, + }); + }); +}); diff --git a/frontend/tests/unit/core/threads/api.test.ts b/frontend/tests/unit/core/threads/api.test.ts index 4d1268694b..9a53520c06 100644 --- a/frontend/tests/unit/core/threads/api.test.ts +++ b/frontend/tests/unit/core/threads/api.test.ts @@ -53,3 +53,91 @@ test("fetchThreadTokenUsage returns null for unavailable token usage", async () await expect(fetchThreadTokenUsage("thread-1")).resolves.toBeNull(); }); + +test("createThreadShare posts selected message ids", async () => { + fetchWithAuth.mockResolvedValue({ + ok: true, + json: async () => ({ + share_id: "share-1", + title: "Shared answer", + created_at: "2026-05-28T00:00:00+00:00", + }), + }); + + const { createThreadShare } = await import("@/core/threads/api"); + + await expect( + createThreadShare({ + threadId: "thread-1", + messageIds: ["human-1", "ai-1"], + title: "Shared answer", + }), + ).resolves.toMatchObject({ share_id: "share-1" }); + + expect(fetchWithAuth).toHaveBeenCalledWith( + expect.stringContaining("/api/shares/threads/thread-1"), + { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + message_ids: ["human-1", "ai-1"], + title: "Shared answer", + }), + }, + ); +}); + +test("createThreadShare rejects with backend error detail", async () => { + fetchWithAuth.mockResolvedValue({ + ok: false, + status: 400, + json: async () => ({ detail: "Message IDs not found: missing-message" }), + }); + + const { createThreadShare } = await import("@/core/threads/api"); + + await expect( + createThreadShare({ + threadId: "thread-1", + messageIds: ["missing-message"], + }), + ).rejects.toThrow("Message IDs not found: missing-message"); +}); + +test("createThreadShare rejects with validation array details", async () => { + fetchWithAuth.mockResolvedValue({ + ok: false, + status: 422, + json: async () => ({ + detail: [ + { msg: "Field required", loc: ["body", "message_ids"] }, + { msg: "String should have at most 256 characters" }, + ], + }), + }); + + const { createThreadShare } = await import("@/core/threads/api"); + + await expect( + createThreadShare({ + threadId: "thread-1", + messageIds: [], + }), + ).rejects.toThrow( + "Field required; String should have at most 256 characters", + ); +}); + +test("readErrorDetail formats object details", async () => { + const { readErrorDetail } = await import("@/core/threads/api"); + + await expect( + readErrorDetail( + { + status: 503, + json: async () => ({ detail: { error: "Store not available" } }), + } as Response, + "Failed to load share", + ), + ).resolves.toBe('{"error":"Store not available"}'); +});