Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions backend/app/channels/feishu.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,8 @@ def down_load():

virtual_path = f"{VIRTUAL_PATH_PREFIX}/uploads/{resolved_target.name}"

sandbox_provider = None
sandbox_id = None
try:
sandbox_provider = get_sandbox_provider()
sandbox_id = sandbox_provider.acquire(thread_id)
Expand All @@ -390,6 +392,12 @@ def down_load():
except Exception:
logger.exception("[Feishu] failed to sync resource into non-local sandbox: %s", virtual_path)
return f"Failed to obtain the [{type}]"
finally:
if sandbox_provider is not None and sandbox_id is not None:
try:
sandbox_provider.release(sandbox_id)
except Exception:
logger.warning("[Feishu] failed to release sandbox %s after file sync", sandbox_id, exc_info=True)
Comment thread
BXL1015 marked this conversation as resolved.

logger.info("[Feishu] downloaded resource mapped: file_key=%s -> %s", file_key, virtual_path)
return virtual_path
Expand Down
36 changes: 28 additions & 8 deletions backend/app/gateway/routers/uploads.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,15 @@ def _uses_thread_data_mounts(sandbox_provider: SandboxProvider) -> bool:
return bool(getattr(sandbox_provider, "uses_thread_data_mounts", False))


def _release_sandbox_after_upload_sync(sandbox_provider: SandboxProvider, sandbox_id: str | None) -> None:
    """Release the sandbox used for an upload sync without ever raising.

    Args:
        sandbox_provider: Provider whose ``release`` method is invoked.
        sandbox_id: ID of the sandbox to release. ``None`` means no sandbox
            was acquired, in which case this is a no-op.

    Any exception raised by ``release`` is logged as a warning (with
    traceback) and swallowed: this helper runs on cleanup paths, so a failed
    release must not replace the upload result or the error already being
    propagated by the caller.
    """
    if sandbox_id is None:
        return
    try:
        sandbox_provider.release(sandbox_id)
    except Exception:
        # Best-effort cleanup: record the failure but do not propagate it.
        logger.warning("Failed to release sandbox %s after upload sync", sandbox_id, exc_info=True)


def _get_uploads_config_value(app_config: AppConfig, key: str, default: object) -> object:
"""Read a value from the uploads config, supporting dict and attribute access."""
uploads_cfg = getattr(app_config, "uploads", None)
Expand Down Expand Up @@ -220,12 +229,18 @@ async def upload_files(
sandbox_provider = get_sandbox_provider()
sync_to_sandbox = not _uses_thread_data_mounts(sandbox_provider)
sandbox = None
sandbox_id = None
if sync_to_sandbox:
sandbox_id = sandbox_provider.acquire(thread_id)
sandbox = sandbox_provider.get(sandbox_id)
if sandbox is None:
_release_sandbox_after_upload_sync(sandbox_provider, sandbox_id)
raise HTTPException(status_code=500, detail="Failed to acquire sandbox")
auto_convert_documents = _auto_convert_documents_enabled(config)
try:
auto_convert_documents = _auto_convert_documents_enabled(config)
except Exception:
_release_sandbox_after_upload_sync(sandbox_provider, sandbox_id)
raise
Comment thread
BXL1015 marked this conversation as resolved.
Outdated

for file in files:
if not file.filename:
Expand Down Expand Up @@ -285,6 +300,7 @@ async def upload_files(

except HTTPException as e:
_cleanup_uploaded_paths(written_paths)
_release_sandbox_after_upload_sync(sandbox_provider, sandbox_id)
raise e
except UnsafeUploadPathError as e:
logger.warning("Skipping upload with unsafe destination %s: %s", file.filename, e)
Expand All @@ -293,6 +309,7 @@ async def upload_files(
except Exception as e:
logger.error(f"Failed to upload {file.filename}: {e}")
_cleanup_uploaded_paths(written_paths)
_release_sandbox_after_upload_sync(sandbox_provider, sandbox_id)
raise HTTPException(status_code=500, detail=f"Failed to upload {file.filename}: {str(e)}")

# Uploaded files are created with 0o600 permissions (owner read/write only).
Expand All @@ -302,13 +319,16 @@ async def upload_files(
# directory is bind-mounted into the container or synced via
# sandbox.update_file. Always add group/other read bits so every sandbox
# configuration can read the uploaded content.
for file_path in written_paths:
_make_file_sandbox_readable(file_path)

if sync_to_sandbox:
for file_path, virtual_path in sandbox_sync_targets:
_make_file_sandbox_writable(file_path)
sandbox.update_file(virtual_path, file_path.read_bytes())
try:
for file_path in written_paths:
_make_file_sandbox_readable(file_path)

if sync_to_sandbox:
for file_path, virtual_path in sandbox_sync_targets:
_make_file_sandbox_writable(file_path)
sandbox.update_file(virtual_path, file_path.read_bytes())
finally:
_release_sandbox_after_upload_sync(sandbox_provider, sandbox_id)

message = f"Successfully uploaded {len(uploaded_files)} file(s)"
if skipped_files:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def __init__(self):
self._thread_sandboxes: dict[str, str] = {} # thread_id -> sandbox_id
self._thread_locks: dict[str, threading.Lock] = {} # thread_id -> in-process lock
self._last_activity: dict[str, float] = {} # sandbox_id -> last activity timestamp
self._lease_counts: dict[str, int] = {} # sandbox_id -> active acquire() lease count
# Warm pool: released sandboxes whose containers are still running.
# Maps sandbox_id -> (SandboxInfo, release_timestamp).
# Containers here can be reclaimed quickly (no cold-start) or destroyed
Expand Down Expand Up @@ -459,6 +460,14 @@ def _sandbox_id_for_thread(self, thread_id: str | None) -> str:
"""Return deterministic IDs for thread sandboxes and random IDs otherwise."""
return self._deterministic_sandbox_id(thread_id) if thread_id else str(uuid.uuid4())[:8]

def _lease_counts_map(self) -> dict[str, int]:
    """Return the active lease map, creating it for test-constructed providers.

    Providers built with ``__new__`` (bypassing ``__init__``) may have no
    ``_lease_counts`` attribute. In that case, or when the attribute is
    ``None``, a fresh dict is stored on the instance and returned, so
    callers always mutate the same mapping of sandbox_id -> lease count.
    """
    lease_counts = getattr(self, "_lease_counts", None)
    if lease_counts is None:
        lease_counts = {}
        # Persist the new map so later calls see the same object.
        self._lease_counts = lease_counts
    return lease_counts

def _reuse_in_process_sandbox(self, thread_id: str | None, *, post_lock: bool = False) -> str | None:
"""Reuse an active in-process sandbox for a thread if one is still tracked."""
if thread_id is None:
Expand All @@ -472,6 +481,8 @@ def _reuse_in_process_sandbox(self, thread_id: str | None, *, post_lock: bool =
if existing_id in self._sandboxes:
suffix = " (post-lock check)" if post_lock else ""
logger.info(f"Reusing in-process sandbox {existing_id} for thread {thread_id}{suffix}")
lease_counts = self._lease_counts_map()
lease_counts[existing_id] = lease_counts.get(existing_id, 0) + 1
self._last_activity[existing_id] = time.time()
return existing_id
Comment thread
BXL1015 marked this conversation as resolved.

Expand All @@ -491,6 +502,7 @@ def _reclaim_warm_pool_sandbox(self, thread_id: str | None, sandbox_id: str, *,
sandbox = AioSandbox(id=sandbox_id, base_url=info.sandbox_url)
self._sandboxes[sandbox_id] = sandbox
self._sandbox_infos[sandbox_id] = info
self._lease_counts_map()[sandbox_id] = 1
self._last_activity[sandbox_id] = time.time()
self._thread_sandboxes[thread_id] = sandbox_id

Expand All @@ -508,6 +520,7 @@ def _register_discovered_sandbox(self, thread_id: str, info: SandboxInfo) -> str
with self._lock:
self._sandboxes[info.sandbox_id] = sandbox
self._sandbox_infos[info.sandbox_id] = info
self._lease_counts_map()[info.sandbox_id] = 1
self._last_activity[info.sandbox_id] = time.time()
self._thread_sandboxes[thread_id] = info.sandbox_id

Expand All @@ -520,6 +533,7 @@ def _register_created_sandbox(self, thread_id: str | None, sandbox_id: str, info
with self._lock:
self._sandboxes[sandbox_id] = sandbox
self._sandbox_infos[sandbox_id] = info
self._lease_counts_map()[sandbox_id] = 1
self._last_activity[sandbox_id] = time.time()
if thread_id:
self._thread_sandboxes[thread_id] = sandbox_id
Expand Down Expand Up @@ -797,6 +811,15 @@ def release(self, sandbox_id: str) -> None:
thread_ids_to_remove: list[str] = []

with self._lock:
lease_counts = self._lease_counts_map()
active_leases = lease_counts.get(sandbox_id, 1 if sandbox_id in self._sandboxes else 0)
if active_leases > 1:
lease_counts[sandbox_id] = active_leases - 1
self._last_activity[sandbox_id] = time.time()
logger.info(f"Released sandbox lease {sandbox_id} ({active_leases - 1} active lease(s) remain)")
return

lease_counts.pop(sandbox_id, None)
self._sandboxes.pop(sandbox_id, None)
info = self._sandbox_infos.pop(sandbox_id, None)
thread_ids_to_remove = [tid for tid, sid in self._thread_sandboxes.items() if sid == sandbox_id]
Expand Down Expand Up @@ -824,6 +847,7 @@ def destroy(self, sandbox_id: str) -> None:
with self._lock:
self._sandboxes.pop(sandbox_id, None)
info = self._sandbox_infos.pop(sandbox_id, None)
self._lease_counts_map().pop(sandbox_id, None)
thread_ids_to_remove = [tid for tid, sid in self._thread_sandboxes.items() if sid == sandbox_id]
for tid in thread_ids_to_remove:
del self._thread_sandboxes[tid]
Expand Down
40 changes: 40 additions & 0 deletions backend/tests/test_aio_sandbox_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def _make_provider(tmp_path):
provider = aio_mod.AioSandboxProvider.__new__(aio_mod.AioSandboxProvider)
provider._config = {}
provider._sandboxes = {}
provider._lease_counts = {}
provider._lock = MagicMock()
provider._idle_checker_stop = MagicMock()
return provider
Expand Down Expand Up @@ -141,6 +142,45 @@ def test_discover_or_create_only_unlocks_when_lock_succeeds(tmp_path, monkeypatc
assert unlock_calls == []


def test_reused_active_sandbox_requires_matching_releases(monkeypatch):
    """A temporary acquire must not release another caller's active sandbox lease."""
    aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
    provider = _make_provider(None)
    provider._config = {"replicas": 3}
    provider._thread_locks = {}
    provider._warm_pool = {}
    provider._sandbox_infos = {}
    provider._thread_sandboxes = {}
    provider._last_activity = {}
    provider._lease_counts = {}
    # Use a real lock here: release() enters ``with self._lock``.
    provider._lock = aio_mod.threading.Lock()
    provider._backend = SimpleNamespace(
        create=MagicMock(return_value=aio_mod.SandboxInfo(sandbox_id="sandbox-lease", sandbox_url="http://sandbox")),
        destroy=MagicMock(),
        discover=MagicMock(return_value=None),
    )
    monkeypatch.setattr(provider, "_get_extra_mounts", lambda _thread_id: [])
    monkeypatch.setattr(aio_mod, "wait_for_sandbox_ready", lambda *_args, **_kwargs: True)

    # Two acquires for the same thread share one sandbox and hold two leases.
    first_id = provider.acquire("thread-lease")
    second_id = provider.acquire("thread-lease")

    assert first_id == second_id
    assert provider._lease_counts[first_id] == 2

    # The first release only drops one lease; the sandbox stays active.
    provider.release(first_id)

    assert first_id in provider._sandboxes
    assert first_id not in provider._warm_pool
    assert provider._lease_counts[first_id] == 1

    # The final release moves the sandbox into the warm pool.
    provider.release(first_id)

    assert first_id not in provider._sandboxes
    assert first_id in provider._warm_pool
    assert first_id not in provider._lease_counts


@pytest.mark.anyio
async def test_acquire_async_uses_async_readiness_polling(monkeypatch):
"""AioSandboxProvider async creation must not use sync readiness polling."""
Expand Down
91 changes: 91 additions & 0 deletions backend/tests/test_feishu_parser.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import asyncio
import json
from io import BytesIO
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock

import pytest
Expand All @@ -17,6 +19,24 @@ def _run(coro):
loop.close()


class _FakeGetMessageResourceRequest:
    """Minimal stand-in for a fluent request builder used by the Feishu tests.

    ``builder()`` returns an instance, every setter ignores its argument and
    returns ``self`` so calls can be chained, and ``build()`` returns an
    opaque placeholder object.
    """

    @classmethod
    def builder(cls):
        return cls()

    def message_id(self, _message_id):
        return self

    def file_key(self, _file_key):
        return self

    # The name mirrors the builder API the tests call, so it intentionally
    # shadows the ``type`` builtin inside the class namespace only.
    def type(self, _type):
        return self

    def build(self):
        return object()


def test_feishu_on_message_plain_text():
bus = MessageBus()
config = {"app_id": "test", "app_secret": "test"}
Expand Down Expand Up @@ -103,6 +123,77 @@ async def go():
_run(go())


def test_feishu_receive_single_file_releases_sandbox_after_sync(tmp_path, monkeypatch):
    """A successful file sync releases the acquired sandbox exactly once."""

    async def go():
        bus = MessageBus()
        channel = FeishuChannel(bus, {"app_id": "test", "app_secret": "test"})
        channel._GetMessageResourceRequest = _FakeGetMessageResourceRequest

        # Fake a successful resource download carrying a small text file.
        response = MagicMock()
        response.success.return_value = True
        response.file = BytesIO(b"hello uploads")
        response.file_name = "note.txt"

        channel._api_client = MagicMock()
        channel._api_client.im.v1.message_resource.get.return_value = response

        # Route the uploads directory to pytest's temporary directory.
        paths = SimpleNamespace(
            ensure_thread_dirs=MagicMock(),
            sandbox_uploads_dir=MagicMock(return_value=tmp_path),
        )
        sandbox = MagicMock()
        provider = MagicMock()
        provider.acquire.return_value = "aio-1"
        provider.get.return_value = sandbox

        monkeypatch.setattr("app.channels.feishu.get_paths", lambda: paths)
        monkeypatch.setattr("app.channels.feishu.get_sandbox_provider", lambda: provider)

        result = await channel._receive_single_file("message-1", "file-key", "file", "thread-a")

        assert result == "/mnt/user-data/uploads/note.txt"
        assert (tmp_path / "note.txt").read_bytes() == b"hello uploads"
        sandbox.update_file.assert_called_once_with("/mnt/user-data/uploads/note.txt", b"hello uploads")
        provider.release.assert_called_once_with("aio-1")

    _run(go())


def test_feishu_receive_single_file_releases_sandbox_when_sync_fails(tmp_path, monkeypatch):
    """A failing sandbox sync still releases the acquired sandbox exactly once."""

    async def go():
        bus = MessageBus()
        channel = FeishuChannel(bus, {"app_id": "test", "app_secret": "test"})
        channel._GetMessageResourceRequest = _FakeGetMessageResourceRequest

        # Fake a successful resource download carrying a small text file.
        response = MagicMock()
        response.success.return_value = True
        response.file = BytesIO(b"hello uploads")
        response.file_name = "note.txt"

        channel._api_client = MagicMock()
        channel._api_client.im.v1.message_resource.get.return_value = response

        # Route the uploads directory to pytest's temporary directory.
        paths = SimpleNamespace(
            ensure_thread_dirs=MagicMock(),
            sandbox_uploads_dir=MagicMock(return_value=tmp_path),
        )
        sandbox = MagicMock()
        # Make the sandbox write fail so the error path is exercised.
        sandbox.update_file.side_effect = RuntimeError("sync failed")
        provider = MagicMock()
        provider.acquire.return_value = "aio-1"
        provider.get.return_value = sandbox

        monkeypatch.setattr("app.channels.feishu.get_paths", lambda: paths)
        monkeypatch.setattr("app.channels.feishu.get_sandbox_provider", lambda: provider)

        result = await channel._receive_single_file("message-1", "file-key", "file", "thread-a")

        assert result == "Failed to obtain the [file]"
        provider.release.assert_called_once_with("aio-1")

    _run(go())


def test_feishu_on_message_extracts_image_and_file_keys():
bus = MessageBus()
channel = FeishuChannel(bus, {"app_id": "test", "app_secret": "test"})
Expand Down
1 change: 1 addition & 0 deletions backend/tests/test_sandbox_orphan_reconciliation.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ def _make_provider_for_reconciliation():
provider._thread_sandboxes = {}
provider._thread_locks = {}
provider._last_activity = {}
provider._lease_counts = {}
provider._warm_pool = {}
provider._shutdown_called = False
provider._idle_checker_stop = threading.Event()
Expand Down
Loading