From d683889edce9dfcd653e3d5b2f1c97b22e1e6b2c Mon Sep 17 00:00:00 2001 From: zhongli <335302680@qq.com> Date: Thu, 28 May 2026 12:24:26 +0800 Subject: [PATCH 1/3] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dchannel=E4=B8=AD=E7=9A=84?= =?UTF-8?q?user=5Fid=E4=BC=A0=E9=80=92=E5=88=B0interceptor=E4=B8=AD?= =?UTF-8?q?=E7=9A=84bug,=20mcp=E5=8F=AF=E9=80=9A=E8=BF=87header=E4=BC=A0?= =?UTF-8?q?=E9=80=92user=5Fid=E5=88=B0mcp=E5=B7=A5=E5=85=B7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Cursor --- backend/app/channels/manager.py | 5 ++++- backend/app/gateway/services.py | 5 +++++ backend/packages/harness/deerflow/mcp/tools.py | 7 ++++++- 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/backend/app/channels/manager.py b/backend/app/channels/manager.py index 015f91e58d..6fed5cace2 100644 --- a/backend/app/channels/manager.py +++ b/backend/app/channels/manager.py @@ -604,7 +604,10 @@ def _resolve_run_params(self, msg: InboundMessage, thread_id: str) -> tuple[str, self._default_session.get("context"), channel_layer.get("context"), user_layer.get("context"), - {"thread_id": thread_id}, + { + "thread_id": thread_id, + "user_id": msg.user_id, + }, ) # Custom agents are implemented as lead_agent + agent_name context. diff --git a/backend/app/gateway/services.py b/backend/app/gateway/services.py index 63ac0c1bf6..2829c27575 100644 --- a/backend/app/gateway/services.py +++ b/backend/app/gateway/services.py @@ -151,6 +151,8 @@ def merge_run_context_overrides(config: dict[str, Any], context: Mapping[str, An configurable.setdefault(key, context[key]) if isinstance(runtime_context, dict): runtime_context.setdefault(key, context[key]) + if "user_id" in context and isinstance(runtime_context, dict): + runtime_context.setdefault("user_id", context["user_id"]) def inject_authenticated_user_context(config: dict[str, Any], request: Request) -> None: @@ -166,6 +168,9 @@ def inject_authenticated_user_context(config: dict[str, Any], request: Request) if user_id is None: return + if getattr(user, "system_role", None) == "internal": + return + runtime_context = config.setdefault("context", {}) if isinstance(runtime_context, dict): runtime_context["user_id"] = str(user_id) diff --git a/backend/packages/harness/deerflow/mcp/tools.py b/backend/packages/harness/deerflow/mcp/tools.py index f38b70375d..3b83e706f4 100644 --- a/backend/packages/harness/deerflow/mcp/tools.py +++ b/backend/packages/harness/deerflow/mcp/tools.py @@ -137,7 +137,12 @@ async def call_with_persistent_session( from langchain_mcp_adapters.interceptors import MCPToolCallRequest async def base_handler(request: MCPToolCallRequest) -> Any: - return await session.call_tool(request.name, request.args) + # Preserve interceptor-injected headers for stdio MCP calls by + # forwarding them through MCP call meta. + call_kwargs: dict[str, Any] = {} + if request.headers: + call_kwargs["meta"] = {"headers": dict(request.headers)} + return await session.call_tool(request.name, request.args, **call_kwargs) handler = base_handler for interceptor in reversed(tool_interceptors): From 99df04142f755328de48c0810416847a83148c61 Mon Sep 17 00:00:00 2001 From: zhongli <335302680@qq.com> Date: Fri, 29 May 2026 11:18:01 +0800 Subject: [PATCH 2/3] fix(channel,mcp,gateway): normalize channel user_id and add regression tests Normalize external channel user ids into filesystem-safe runtime context while preserving raw channel_user_id, and document gateway user_id propagation semantics. Add regression coverage for channel user_id context mapping, gateway user_id precedence/internal-role behavior, and MCP interceptor header forwarding via meta.headers. Co-authored-by: Cursor --- backend/app/channels/manager.py | 14 ++- backend/app/gateway/services.py | 9 +- .../packages/harness/deerflow/config/paths.py | 19 +++++ backend/tests/test_channels.py | 35 ++++++++ backend/tests/test_gateway_services.py | 43 ++++++++++ backend/tests/test_mcp_session_pool.py | 85 +++++++++++++++++++ backend/tests/test_paths_user_isolation.py | 34 ++++++++ 7 files changed, 234 insertions(+), 5 deletions(-) diff --git a/backend/app/channels/manager.py b/backend/app/channels/manager.py index 6fed5cace2..f1a4db8fa0 100644 --- a/backend/app/channels/manager.py +++ b/backend/app/channels/manager.py @@ -19,6 +19,7 @@ from app.channels.store import ChannelStore from app.gateway.csrf_middleware import CSRF_COOKIE_NAME, CSRF_HEADER_NAME, generate_csrf_token from app.gateway.internal_auth import create_internal_auth_headers +from deerflow.config.paths import make_safe_user_id from deerflow.runtime.user_context import get_effective_user_id logger = logging.getLogger(__name__) @@ -599,15 +600,20 @@ def _resolve_run_params(self, msg: InboundMessage, thread_id: str) -> tuple[str, configurable["checkpoint_ns"] = "" configurable["thread_id"] = thread_id + # ``user_id`` drives user-scoped filesystem buckets that only accept + # ``[A-Za-z0-9_-]``, so normalize the channel id and keep the raw value + # under ``channel_user_id`` for platform-facing lookups. + run_context_identity: dict[str, Any] = {"thread_id": thread_id} + if msg.user_id: + run_context_identity["user_id"] = make_safe_user_id(msg.user_id) + run_context_identity["channel_user_id"] = msg.user_id + run_context = _merge_dicts( DEFAULT_RUN_CONTEXT, self._default_session.get("context"), channel_layer.get("context"), user_layer.get("context"), - { - "thread_id": thread_id, - "user_id": msg.user_id, - }, + run_context_identity, ) # Custom agents are implemented as lead_agent + agent_name context. diff --git a/backend/app/gateway/services.py b/backend/app/gateway/services.py index 2829c27575..1be3f90b69 100644 --- a/backend/app/gateway/services.py +++ b/backend/app/gateway/services.py @@ -140,7 +140,14 @@ def merge_run_context_overrides(config: dict[str, Any], context: Mapping[str, An """Merge whitelisted keys from ``body.context`` into both ``config['configurable']`` and ``config['context']`` so they are visible to legacy configurable readers and to LangGraph ``ToolRuntime.context`` consumers (e.g. the ``setup_agent`` tool — - see issue #2677).""" + see issue #2677). + + ``user_id`` is intentionally propagated into ``config['context']`` in addition to + the whitelisted keys, so non-web callers (e.g. IM channels) that supply identity in + ``body.context`` keep it on ``ToolRuntime.context``. It is merged with + ``setdefault`` so a server-authenticated id stamped by + :func:`inject_authenticated_user_context` always wins over the client-supplied one. + """ if not context: return configurable = config.setdefault("configurable", {}) diff --git a/backend/packages/harness/deerflow/config/paths.py b/backend/packages/harness/deerflow/config/paths.py index c068390401..77035ad8b8 100644 --- a/backend/packages/harness/deerflow/config/paths.py +++ b/backend/packages/harness/deerflow/config/paths.py @@ -1,3 +1,4 @@ +import hashlib import os import re import shutil @@ -10,6 +11,7 @@ _SAFE_THREAD_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$") _SAFE_USER_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$") +_UNSAFE_USER_ID_CHAR_RE = re.compile(r"[^A-Za-z0-9_\-]") def _default_local_base_dir() -> Path: @@ -31,6 +33,23 @@ def _validate_user_id(user_id: str) -> str: return user_id +def make_safe_user_id(raw: str) -> str: + """Normalize an external identity into the user-id charset (``[A-Za-z0-9_-]``). + + IM channel ids (Feishu/Slack/Telegram) may contain characters that + :func:`_validate_user_id` rejects. Already-safe ids pass through unchanged; + lossy ones get a short digest suffix so two distinct inputs never share a + storage bucket. + """ + if not raw: + raise ValueError("user_id must be a non-empty string.") + sanitized = _UNSAFE_USER_ID_CHAR_RE.sub("-", raw) + if sanitized == raw: + return raw + digest = hashlib.sha1(raw.encode("utf-8")).hexdigest()[:8] + return f"{sanitized}-{digest}" + + def _join_host_path(base: str, *parts: str) -> str: """Join host filesystem path segments while preserving native style. diff --git a/backend/tests/test_channels.py b/backend/tests/test_channels.py index 61a402defc..cf21500c4f 100644 --- a/backend/tests/test_channels.py +++ b/backend/tests/test_channels.py @@ -1591,6 +1591,41 @@ async def capture(msg): _run(go()) +class TestResolveRunParamsUserId: + """Regression for PR #3294: channel identity must reach ``run_context`` + while staying safe for user-scoped filesystem buckets. + """ + + def _manager(self): + from app.channels.manager import ChannelManager + + bus = MessageBus() + store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json") + return ChannelManager(bus=bus, store=store) + + def test_safe_user_id_is_passed_through(self): + manager = self._manager() + msg = InboundMessage(channel_name="telegram", chat_id="c", user_id="123456", text="hi") + + _, _, run_context = manager._resolve_run_params(msg, "thread-1") + + assert run_context["user_id"] == "123456" + assert run_context["channel_user_id"] == "123456" + + def test_unsafe_user_id_is_normalized_but_raw_preserved(self): + from deerflow.config.paths import make_safe_user_id + + manager = self._manager() + raw = "user@example.com" + msg = InboundMessage(channel_name="feishu", chat_id="c", user_id=raw, text="hi") + + _, _, run_context = manager._resolve_run_params(msg, "thread-1") + + assert run_context["user_id"] == make_safe_user_id(raw) + assert run_context["user_id"] != raw + assert run_context["channel_user_id"] == raw + + # --------------------------------------------------------------------------- # ChannelService tests # --------------------------------------------------------------------------- diff --git a/backend/tests/test_gateway_services.py b/backend/tests/test_gateway_services.py index 2ccd372bf7..d62ed9371d 100644 --- a/backend/tests/test_gateway_services.py +++ b/backend/tests/test_gateway_services.py @@ -431,6 +431,49 @@ def test_inject_authenticated_user_context_overrides_client_user_id(): assert config["context"]["user_id"] == "auth-user-42" +def test_merge_run_context_overrides_propagates_user_id(): + """Regression for PR #3294: ``user_id`` from ``body.context`` must land in + ``config['context']`` so non-web callers (e.g. IM channels) keep their identity + on ``ToolRuntime.context``. + """ + from app.gateway.services import build_run_config, merge_run_context_overrides + + config = build_run_config("thread-1", None, None) + merge_run_context_overrides(config, {"user_id": "channel-user-7"}) + + assert config["context"]["user_id"] == "channel-user-7" + + +def test_merge_run_context_overrides_does_not_clobber_existing_user_id(): + """``merge_run_context_overrides`` must not override an already-stamped + authenticated ``context.user_id`` with the client-supplied value. + """ + from app.gateway.services import build_run_config, merge_run_context_overrides + + config = build_run_config("thread-1", {"context": {"user_id": "auth-user-42"}}, None) + merge_run_context_overrides(config, {"user_id": "spoofed-client"}) + + assert config["context"]["user_id"] == "auth-user-42" + + +def test_inject_authenticated_user_context_skips_internal_role(): + """Regression for PR #3294: internal system-role callers must not overwrite an + already-present ``context.user_id`` (e.g. a channel-supplied identity), so the + real end user keeps owning the per-user storage bucket. + """ + from types import SimpleNamespace + + from app.gateway.services import build_run_config, inject_authenticated_user_context + + config = build_run_config("thread-1", None, None) + config["context"] = {"user_id": "channel-user-7"} + request = SimpleNamespace(state=SimpleNamespace(user=SimpleNamespace(id="internal-bot", system_role="internal"))) + + inject_authenticated_user_context(config, request) + + assert config["context"]["user_id"] == "channel-user-7" + + # --------------------------------------------------------------------------- # build_run_config — context / configurable precedence (LangGraph >= 0.6.0) # --------------------------------------------------------------------------- diff --git a/backend/tests/test_mcp_session_pool.py b/backend/tests/test_mcp_session_pool.py index 1e2ce7adc3..30d32b7740 100644 --- a/backend/tests/test_mcp_session_pool.py +++ b/backend/tests/test_mcp_session_pool.py @@ -256,6 +256,91 @@ class Args(BaseModel): mock_session.call_tool.assert_awaited_once_with("navigate", {"url": "https://example.com"}) +@pytest.mark.asyncio +async def test_session_pool_tool_forwards_interceptor_headers(): + """Regression for PR #3294: when an interceptor sets ``request.headers``, the + pooled stdio call must forward them via ``meta={"headers": ...}`` so downstream + MCP servers can read auth/context headers. + """ + from langchain_core.tools import StructuredTool + from pydantic import BaseModel, Field + + from deerflow.mcp.tools import _make_session_pool_tool + + class Args(BaseModel): + x: int = Field(..., description="x") + + original_tool = StructuredTool( + name="srv_act", + description="test", + args_schema=Args, + coroutine=AsyncMock(), + response_format="content_and_artifact", + ) + + mock_session = AsyncMock() + mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None)) + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_session) + mock_cm.__aexit__ = AsyncMock(return_value=False) + + async def header_interceptor(request, handler): + return await handler(request.override(headers={"X-User-Id": "u-42"})) + + with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm): + wrapped = _make_session_pool_tool( + original_tool, + "srv", + {"transport": "stdio", "command": "x", "args": []}, + tool_interceptors=[header_interceptor], + ) + await wrapped.coroutine(runtime=None, x=1) + + mock_session.call_tool.assert_awaited_once_with("act", {"x": 1}, meta={"headers": {"X-User-Id": "u-42"}}) + + +@pytest.mark.asyncio +async def test_session_pool_tool_no_headers_omits_meta(): + """When no interceptor sets headers, the pooled call must not pass a ``meta`` + kwarg (falls back to the plain two-argument ``call_tool``). + """ + from langchain_core.tools import StructuredTool + from pydantic import BaseModel, Field + + from deerflow.mcp.tools import _make_session_pool_tool + + class Args(BaseModel): + x: int = Field(..., description="x") + + original_tool = StructuredTool( + name="srv_act", + description="test", + args_schema=Args, + coroutine=AsyncMock(), + response_format="content_and_artifact", + ) + + mock_session = AsyncMock() + mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None)) + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_session) + mock_cm.__aexit__ = AsyncMock(return_value=False) + + async def passthrough_interceptor(request, handler): + return await handler(request) + + with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm): + wrapped = _make_session_pool_tool( + original_tool, + "srv", + {"transport": "stdio", "command": "x", "args": []}, + tool_interceptors=[passthrough_interceptor], + ) + await wrapped.coroutine(runtime=None, x=1) + + mock_session.call_tool.assert_awaited_once_with("act", {"x": 1}) + + @pytest.mark.asyncio async def test_session_pool_tool_extracts_thread_id(): """Thread ID is extracted from runtime.config when not in context.""" diff --git a/backend/tests/test_paths_user_isolation.py b/backend/tests/test_paths_user_isolation.py index d5c0f540f1..9f9ade045c 100644 --- a/backend/tests/test_paths_user_isolation.py +++ b/backend/tests/test_paths_user_isolation.py @@ -30,6 +30,40 @@ def test_rejects_empty(self, paths: Paths): paths.user_dir("") +class TestMakeSafeUserId: + def test_already_safe_id_is_unchanged(self): + from deerflow.config.paths import make_safe_user_id + + assert make_safe_user_id("ou_abc-123") == "ou_abc-123" + assert make_safe_user_id("123456") == "123456" + + def test_unsafe_chars_are_sanitized_with_stable_suffix(self): + from deerflow.config.paths import make_safe_user_id + + result = make_safe_user_id("user@example.com") + # Sanitized prefix plus a stable digest of the original. + assert result.startswith("user-example-com-") + assert make_safe_user_id("user@example.com") == result + + def test_sanitized_id_passes_validation(self, paths: Paths): + from deerflow.config.paths import make_safe_user_id + + safe = make_safe_user_id("用户/../etc") + # Must be usable as a filesystem-scoped bucket without raising. + assert paths.user_dir(safe) == paths.base_dir / "users" / safe + + def test_distinct_unsafe_ids_do_not_collide(self): + from deerflow.config.paths import make_safe_user_id + + assert make_safe_user_id("a.b") != make_safe_user_id("a/b") + + def test_empty_id_rejected(self): + from deerflow.config.paths import make_safe_user_id + + with pytest.raises(ValueError, match="non-empty"): + make_safe_user_id("") + + class TestUserDir: def test_user_dir(self, paths: Paths): assert paths.user_dir("alice") == paths.base_dir / "users" / "alice" From 8f49b5cbac820534aa68772bd4b4c00e731f5a18 Mon Sep 17 00:00:00 2001 From: zhongli <335302680@qq.com> Date: Mon, 1 Jun 2026 10:26:07 +0800 Subject: [PATCH 3/3] fix(auth,mcp): harden user id normalization and header handling Increase sanitized user-id digest suffix to 16 hex chars, replace internal system role magic string with a shared constant, and harden MCP header forwarding with Mapping type checks. Add regression tests for empty channel user_id handling, unsupported header types, and updated digest length behavior. Co-authored-by: Cursor --- backend/app/gateway/internal_auth.py | 3 +- backend/app/gateway/services.py | 3 +- .../packages/harness/deerflow/config/paths.py | 3 +- .../packages/harness/deerflow/mcp/tools.py | 6 ++- backend/tests/test_channels.py | 10 +++++ backend/tests/test_mcp_session_pool.py | 45 +++++++++++++++++++ backend/tests/test_paths_user_isolation.py | 1 + 7 files changed, 67 insertions(+), 4 deletions(-) diff --git a/backend/app/gateway/internal_auth.py b/backend/app/gateway/internal_auth.py index 51ed89a99e..3a00a9662b 100644 --- a/backend/app/gateway/internal_auth.py +++ b/backend/app/gateway/internal_auth.py @@ -10,6 +10,7 @@ INTERNAL_AUTH_HEADER_NAME = "X-DeerFlow-Internal-Token" INTERNAL_AUTH_ENV_VAR = "DEER_FLOW_INTERNAL_AUTH_TOKEN" +INTERNAL_SYSTEM_ROLE = "internal" def _load_internal_auth_token() -> str: @@ -34,4 +35,4 @@ def is_valid_internal_auth_token(token: str | None) -> bool: def get_internal_user(): """Return the synthetic user used for trusted internal channel calls.""" - return SimpleNamespace(id=DEFAULT_USER_ID, system_role="internal") + return SimpleNamespace(id=DEFAULT_USER_ID, system_role=INTERNAL_SYSTEM_ROLE) diff --git a/backend/app/gateway/services.py b/backend/app/gateway/services.py index 1be3f90b69..2c5c01e61b 100644 --- a/backend/app/gateway/services.py +++ b/backend/app/gateway/services.py @@ -19,6 +19,7 @@ from langchain_core.messages.utils import convert_to_messages from app.gateway.deps import get_run_context, get_run_manager, get_stream_bridge +from app.gateway.internal_auth import INTERNAL_SYSTEM_ROLE from app.gateway.utils import sanitize_log_param from deerflow.config.app_config import get_app_config from deerflow.runtime import ( @@ -175,7 +176,7 @@ def inject_authenticated_user_context(config: dict[str, Any], request: Request) if user_id is None: return - if getattr(user, "system_role", None) == "internal": + if getattr(user, "system_role", None) == INTERNAL_SYSTEM_ROLE: return runtime_context = config.setdefault("context", {}) diff --git a/backend/packages/harness/deerflow/config/paths.py b/backend/packages/harness/deerflow/config/paths.py index 77035ad8b8..f01959657e 100644 --- a/backend/packages/harness/deerflow/config/paths.py +++ b/backend/packages/harness/deerflow/config/paths.py @@ -12,6 +12,7 @@ _SAFE_THREAD_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$") _SAFE_USER_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$") _UNSAFE_USER_ID_CHAR_RE = re.compile(r"[^A-Za-z0-9_\-]") +_SAFE_USER_ID_DIGEST_HEX_LEN = 16 def _default_local_base_dir() -> Path: @@ -46,7 +47,7 @@ def make_safe_user_id(raw: str) -> str: sanitized = _UNSAFE_USER_ID_CHAR_RE.sub("-", raw) if sanitized == raw: return raw - digest = hashlib.sha1(raw.encode("utf-8")).hexdigest()[:8] + digest = hashlib.sha1(raw.encode("utf-8")).hexdigest()[:_SAFE_USER_ID_DIGEST_HEX_LEN] return f"{sanitized}-{digest}" diff --git a/backend/packages/harness/deerflow/mcp/tools.py b/backend/packages/harness/deerflow/mcp/tools.py index 3b83e706f4..e425efe0c2 100644 --- a/backend/packages/harness/deerflow/mcp/tools.py +++ b/backend/packages/harness/deerflow/mcp/tools.py @@ -3,6 +3,7 @@ from __future__ import annotations import logging +from collections.abc import Mapping from typing import Any from langchain_core.tools import BaseTool, StructuredTool @@ -141,7 +142,10 @@ async def base_handler(request: MCPToolCallRequest) -> Any: # forwarding them through MCP call meta. call_kwargs: dict[str, Any] = {} if request.headers: - call_kwargs["meta"] = {"headers": dict(request.headers)} + if isinstance(request.headers, Mapping): + call_kwargs["meta"] = {"headers": dict(request.headers)} + else: + logger.warning("Ignoring MCP interceptor headers with unsupported type: %s", type(request.headers).__name__) return await session.call_tool(request.name, request.args, **call_kwargs) handler = base_handler diff --git a/backend/tests/test_channels.py b/backend/tests/test_channels.py index cf21500c4f..aed0b95cc8 100644 --- a/backend/tests/test_channels.py +++ b/backend/tests/test_channels.py @@ -1625,6 +1625,16 @@ def test_unsafe_user_id_is_normalized_but_raw_preserved(self): assert run_context["user_id"] != raw assert run_context["channel_user_id"] == raw + @pytest.mark.parametrize("raw_user_id", ["", None]) + def test_empty_or_none_user_id_is_not_injected(self, raw_user_id): + manager = self._manager() + msg = InboundMessage(channel_name="feishu", chat_id="c", user_id=raw_user_id, text="hi") + + _, _, run_context = manager._resolve_run_params(msg, "thread-1") + + assert "user_id" not in run_context + assert "channel_user_id" not in run_context + # --------------------------------------------------------------------------- # ChannelService tests diff --git a/backend/tests/test_mcp_session_pool.py b/backend/tests/test_mcp_session_pool.py index 30d32b7740..40d02e61c1 100644 --- a/backend/tests/test_mcp_session_pool.py +++ b/backend/tests/test_mcp_session_pool.py @@ -341,6 +341,51 @@ async def passthrough_interceptor(request, handler): mock_session.call_tool.assert_awaited_once_with("act", {"x": 1}) +@pytest.mark.asyncio +async def test_session_pool_tool_ignores_unsupported_header_type(caplog): + """Defensive path: non-mapping truthy headers should be ignored safely.""" + from langchain_core.tools import StructuredTool + from pydantic import BaseModel, Field + + from deerflow.mcp.tools import _make_session_pool_tool + + class Args(BaseModel): + x: int = Field(..., description="x") + + class TruthyHeaders: + def __bool__(self) -> bool: + return True + + original_tool = StructuredTool( + name="srv_act", + description="test", + args_schema=Args, + coroutine=AsyncMock(), + response_format="content_and_artifact", + ) + + mock_session = AsyncMock() + mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None)) + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_session) + mock_cm.__aexit__ = AsyncMock(return_value=False) + + async def invalid_header_interceptor(request, handler): + return await handler(request.override(headers=TruthyHeaders())) + + with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm): + wrapped = _make_session_pool_tool( + original_tool, + "srv", + {"transport": "stdio", "command": "x", "args": []}, + tool_interceptors=[invalid_header_interceptor], + ) + await wrapped.coroutine(runtime=None, x=1) + + mock_session.call_tool.assert_awaited_once_with("act", {"x": 1}) + assert "unsupported type" in caplog.text + + @pytest.mark.asyncio async def test_session_pool_tool_extracts_thread_id(): """Thread ID is extracted from runtime.config when not in context.""" diff --git a/backend/tests/test_paths_user_isolation.py b/backend/tests/test_paths_user_isolation.py index 9f9ade045c..692c526ed2 100644 --- a/backend/tests/test_paths_user_isolation.py +++ b/backend/tests/test_paths_user_isolation.py @@ -43,6 +43,7 @@ def test_unsafe_chars_are_sanitized_with_stable_suffix(self): result = make_safe_user_id("user@example.com") # Sanitized prefix plus a stable digest of the original. assert result.startswith("user-example-com-") + assert len(result.rsplit("-", 1)[1]) == 16 assert make_safe_user_id("user@example.com") == result def test_sanitized_id_passes_validation(self, paths: Paths):