Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions backend/app/channels/feishu.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,8 @@ def down_load():

virtual_path = f"{VIRTUAL_PATH_PREFIX}/uploads/{resolved_target.name}"

sandbox_provider = None
sandbox_id = None
try:
sandbox_provider = get_sandbox_provider()
sandbox_id = sandbox_provider.acquire(thread_id)
Expand All @@ -390,6 +392,12 @@ def down_load():
except Exception:
logger.exception("[Feishu] failed to sync resource into non-local sandbox: %s", virtual_path)
return f"Failed to obtain the [{type}]"
finally:
if sandbox_provider is not None and sandbox_id is not None and sandbox_id != "local":
try:
sandbox_provider.release(sandbox_id)
except Exception:
logger.warning("[Feishu] failed to release sandbox %s after file sync", sandbox_id, exc_info=True)
Comment thread
BXL1015 marked this conversation as resolved.

logger.info("[Feishu] downloaded resource mapped: file_key=%s -> %s", file_key, virtual_path)
return virtual_path
Expand Down
30 changes: 23 additions & 7 deletions backend/app/gateway/routers/uploads.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,15 @@ def _uses_thread_data_mounts(sandbox_provider: SandboxProvider) -> bool:
return bool(getattr(sandbox_provider, "uses_thread_data_mounts", False))


def _release_sandbox_after_upload_sync(sandbox_provider: SandboxProvider, sandbox_id: str | None) -> None:
    """Best-effort release of a sandbox that was acquired for an upload sync.

    Args:
        sandbox_provider: Provider whose ``release`` method is invoked.
        sandbox_id: ID returned by the earlier ``acquire`` call, or ``None``
            when no sandbox was acquired (in which case nothing happens).

    A failure inside ``release`` is logged as a warning with its traceback
    and is not propagated to the caller.
    """
    if sandbox_id is not None:
        try:
            sandbox_provider.release(sandbox_id)
        except Exception:
            # Cleanup only: a failed release must not mask the upload outcome.
            logger.warning("Failed to release sandbox %s after upload sync", sandbox_id, exc_info=True)


def _get_uploads_config_value(app_config: AppConfig, key: str, default: object) -> object:
"""Read a value from the uploads config, supporting dict and attribute access."""
uploads_cfg = getattr(app_config, "uploads", None)
Expand Down Expand Up @@ -220,10 +229,12 @@ async def upload_files(
sandbox_provider = get_sandbox_provider()
sync_to_sandbox = not _uses_thread_data_mounts(sandbox_provider)
sandbox = None
sandbox_id = None
if sync_to_sandbox:
sandbox_id = sandbox_provider.acquire(thread_id)
sandbox = sandbox_provider.get(sandbox_id)
if sandbox is None:
_release_sandbox_after_upload_sync(sandbox_provider, sandbox_id)
raise HTTPException(status_code=500, detail="Failed to acquire sandbox")
auto_convert_documents = _auto_convert_documents_enabled(config)

Expand Down Expand Up @@ -285,6 +296,7 @@ async def upload_files(

except HTTPException as e:
_cleanup_uploaded_paths(written_paths)
_release_sandbox_after_upload_sync(sandbox_provider, sandbox_id)
raise e
except UnsafeUploadPathError as e:
logger.warning("Skipping upload with unsafe destination %s: %s", file.filename, e)
Expand All @@ -293,6 +305,7 @@ async def upload_files(
except Exception as e:
logger.error(f"Failed to upload {file.filename}: {e}")
_cleanup_uploaded_paths(written_paths)
_release_sandbox_after_upload_sync(sandbox_provider, sandbox_id)
raise HTTPException(status_code=500, detail=f"Failed to upload {file.filename}: {str(e)}")

# Uploaded files are created with 0o600 permissions (owner read/write only).
Expand All @@ -302,13 +315,16 @@ async def upload_files(
# directory is bind-mounted into the container or synced via
# sandbox.update_file. Always add group/other read bits so every sandbox
# configuration can read the uploaded content.
for file_path in written_paths:
_make_file_sandbox_readable(file_path)

if sync_to_sandbox:
for file_path, virtual_path in sandbox_sync_targets:
_make_file_sandbox_writable(file_path)
sandbox.update_file(virtual_path, file_path.read_bytes())
try:
for file_path in written_paths:
_make_file_sandbox_readable(file_path)

if sync_to_sandbox:
for file_path, virtual_path in sandbox_sync_targets:
_make_file_sandbox_writable(file_path)
sandbox.update_file(virtual_path, file_path.read_bytes())
finally:
_release_sandbox_after_upload_sync(sandbox_provider, sandbox_id)

message = f"Successfully uploaded {len(uploaded_files)} file(s)"
if skipped_files:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def __init__(self):
self._thread_sandboxes: dict[str, str] = {} # thread_id -> sandbox_id
self._thread_locks: dict[str, threading.Lock] = {} # thread_id -> in-process lock
self._last_activity: dict[str, float] = {} # sandbox_id -> last activity timestamp
self._lease_counts: dict[str, int] = {} # sandbox_id -> active acquire() lease count
# Warm pool: released sandboxes whose containers are still running.
# Maps sandbox_id -> (SandboxInfo, release_timestamp).
# Containers here can be reclaimed quickly (no cold-start) or destroyed
Expand Down Expand Up @@ -459,6 +460,14 @@ def _sandbox_id_for_thread(self, thread_id: str | None) -> str:
"""Return deterministic IDs for thread sandboxes and random IDs otherwise."""
return self._deterministic_sandbox_id(thread_id) if thread_id else str(uuid.uuid4())[:8]

def _lease_counts_map(self) -> dict[str, int]:
"""Return the active lease map, creating it for test-constructed providers."""
lease_counts = getattr(self, "_lease_counts", None)
if lease_counts is None:
lease_counts = {}
self._lease_counts = lease_counts
return lease_counts

def _reuse_in_process_sandbox(self, thread_id: str | None, *, post_lock: bool = False) -> str | None:
"""Reuse an active in-process sandbox for a thread if one is still tracked."""
if thread_id is None:
Expand All @@ -472,6 +481,8 @@ def _reuse_in_process_sandbox(self, thread_id: str | None, *, post_lock: bool =
if existing_id in self._sandboxes:
suffix = " (post-lock check)" if post_lock else ""
logger.info(f"Reusing in-process sandbox {existing_id} for thread {thread_id}{suffix}")
lease_counts = self._lease_counts_map()
lease_counts[existing_id] = lease_counts.get(existing_id, 0) + 1
self._last_activity[existing_id] = time.time()
return existing_id
Comment thread
BXL1015 marked this conversation as resolved.

Expand All @@ -491,6 +502,7 @@ def _reclaim_warm_pool_sandbox(self, thread_id: str | None, sandbox_id: str, *,
sandbox = AioSandbox(id=sandbox_id, base_url=info.sandbox_url)
self._sandboxes[sandbox_id] = sandbox
self._sandbox_infos[sandbox_id] = info
self._lease_counts_map()[sandbox_id] = 1
self._last_activity[sandbox_id] = time.time()
self._thread_sandboxes[thread_id] = sandbox_id

Expand All @@ -508,6 +520,7 @@ def _register_discovered_sandbox(self, thread_id: str, info: SandboxInfo) -> str
with self._lock:
self._sandboxes[info.sandbox_id] = sandbox
self._sandbox_infos[info.sandbox_id] = info
self._lease_counts_map()[info.sandbox_id] = 1
self._last_activity[info.sandbox_id] = time.time()
self._thread_sandboxes[thread_id] = info.sandbox_id

Expand All @@ -520,6 +533,7 @@ def _register_created_sandbox(self, thread_id: str | None, sandbox_id: str, info
with self._lock:
self._sandboxes[sandbox_id] = sandbox
self._sandbox_infos[sandbox_id] = info
self._lease_counts_map()[sandbox_id] = 1
self._last_activity[sandbox_id] = time.time()
if thread_id:
self._thread_sandboxes[thread_id] = sandbox_id
Expand Down Expand Up @@ -797,6 +811,15 @@ def release(self, sandbox_id: str) -> None:
thread_ids_to_remove: list[str] = []

with self._lock:
lease_counts = self._lease_counts_map()
active_leases = lease_counts.get(sandbox_id, 1 if sandbox_id in self._sandboxes else 0)
if active_leases > 1:
lease_counts[sandbox_id] = active_leases - 1
self._last_activity[sandbox_id] = time.time()
logger.info(f"Released sandbox lease {sandbox_id} ({active_leases - 1} active lease(s) remain)")
return

lease_counts.pop(sandbox_id, None)
self._sandboxes.pop(sandbox_id, None)
info = self._sandbox_infos.pop(sandbox_id, None)
thread_ids_to_remove = [tid for tid, sid in self._thread_sandboxes.items() if sid == sandbox_id]
Expand Down Expand Up @@ -824,6 +847,7 @@ def destroy(self, sandbox_id: str) -> None:
with self._lock:
self._sandboxes.pop(sandbox_id, None)
info = self._sandbox_infos.pop(sandbox_id, None)
self._lease_counts_map().pop(sandbox_id, None)
thread_ids_to_remove = [tid for tid, sid in self._thread_sandboxes.items() if sid == sandbox_id]
for tid in thread_ids_to_remove:
del self._thread_sandboxes[tid]
Expand Down
40 changes: 40 additions & 0 deletions backend/tests/test_aio_sandbox_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def _make_provider(tmp_path):
provider = aio_mod.AioSandboxProvider.__new__(aio_mod.AioSandboxProvider)
provider._config = {}
provider._sandboxes = {}
provider._lease_counts = {}
provider._lock = MagicMock()
provider._idle_checker_stop = MagicMock()
return provider
Expand Down Expand Up @@ -141,6 +142,45 @@ def test_discover_or_create_only_unlocks_when_lock_succeeds(tmp_path, monkeypatc
assert unlock_calls == []


def test_reused_active_sandbox_requires_matching_releases(monkeypatch):
    """Each acquire of an active sandbox adds a lease; only the last release frees it.

    Two acquires for the same thread must return the same sandbox with a
    lease count of 2. The first release only decrements the count; the second
    moves the sandbox out of the active map and into the warm pool.
    """
    aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")

    # Hand-assemble the provider state that acquire()/release() touch.
    provider = _make_provider(None)
    provider._config = {"replicas": 3}
    for state_attr in (
        "_thread_locks",
        "_warm_pool",
        "_sandbox_infos",
        "_thread_sandboxes",
        "_last_activity",
        "_lease_counts",
    ):
        setattr(provider, state_attr, {})
    # A real lock is used here instead of the MagicMock from _make_provider.
    provider._lock = aio_mod.threading.Lock()

    created_info = aio_mod.SandboxInfo(sandbox_id="sandbox-lease", sandbox_url="http://sandbox")
    provider._backend = SimpleNamespace(
        create=MagicMock(return_value=created_info),
        destroy=MagicMock(),
        discover=MagicMock(return_value=None),
    )
    monkeypatch.setattr(provider, "_get_extra_mounts", lambda _thread_id: [])
    monkeypatch.setattr(aio_mod, "wait_for_sandbox_ready", lambda *_args, **_kwargs: True)

    first_id = provider.acquire("thread-lease")
    second_id = provider.acquire("thread-lease")

    assert first_id == second_id
    assert provider._lease_counts[first_id] == 2

    # First release: one lease remains, so the sandbox stays active.
    provider.release(first_id)

    assert first_id in provider._sandboxes
    assert first_id not in provider._warm_pool
    assert provider._lease_counts[first_id] == 1

    # Second release: last lease dropped, sandbox is parked in the warm pool.
    provider.release(first_id)

    assert first_id not in provider._sandboxes
    assert first_id in provider._warm_pool
    assert first_id not in provider._lease_counts


@pytest.mark.anyio
async def test_acquire_async_uses_async_readiness_polling(monkeypatch):
"""AioSandboxProvider async creation must not use sync readiness polling."""
Expand Down
125 changes: 125 additions & 0 deletions backend/tests/test_feishu_parser.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import asyncio
import json
from io import BytesIO
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock

import pytest
Expand All @@ -17,6 +19,24 @@ def _run(coro):
loop.close()


class _FakeGetMessageResourceRequest:
@classmethod
def builder(cls):
return cls()

def message_id(self, _message_id):
return self

def file_key(self, _file_key):
return self

def type(self, _type):
return self

def build(self):
return object()


def test_feishu_on_message_plain_text():
bus = MessageBus()
config = {"app_id": "test", "app_secret": "test"}
Expand Down Expand Up @@ -103,6 +123,111 @@ async def go():
_run(go())


def test_feishu_receive_single_file_releases_sandbox_after_sync(tmp_path, monkeypatch):
    """A successful sync into a non-local sandbox releases that sandbox exactly once."""
    payload = b"hello uploads"

    async def scenario():
        channel = FeishuChannel(MessageBus(), {"app_id": "test", "app_secret": "test"})
        channel._GetMessageResourceRequest = _FakeGetMessageResourceRequest

        # Fake a successful resource download from the Feishu API client.
        resource = MagicMock()
        resource.success.return_value = True
        resource.file = BytesIO(payload)
        resource.file_name = "note.txt"
        api_client = MagicMock()
        api_client.im.v1.message_resource.get.return_value = resource
        channel._api_client = api_client

        # Provider hands out a non-local sandbox id and a sandbox double.
        sandbox = MagicMock()
        provider = MagicMock()
        provider.acquire.return_value = "aio-1"
        provider.get.return_value = sandbox

        fake_paths = SimpleNamespace(
            ensure_thread_dirs=MagicMock(),
            sandbox_uploads_dir=MagicMock(return_value=tmp_path),
        )
        monkeypatch.setattr("app.channels.feishu.get_paths", lambda: fake_paths)
        monkeypatch.setattr("app.channels.feishu.get_sandbox_provider", lambda: provider)

        result = await channel._receive_single_file("message-1", "file-key", "file", "thread-a")

        assert result == "/mnt/user-data/uploads/note.txt"
        assert (tmp_path / "note.txt").read_bytes() == payload
        sandbox.update_file.assert_called_once_with("/mnt/user-data/uploads/note.txt", payload)
        provider.release.assert_called_once_with("aio-1")

    _run(scenario())


def test_feishu_receive_single_file_skips_release_for_local_sandbox(tmp_path, monkeypatch):
    """When acquire() yields the "local" sandbox id, neither get() nor release() is called."""
    payload = b"hello uploads"

    async def scenario():
        channel = FeishuChannel(MessageBus(), {"app_id": "test", "app_secret": "test"})
        channel._GetMessageResourceRequest = _FakeGetMessageResourceRequest

        # Fake a successful resource download from the Feishu API client.
        resource = MagicMock()
        resource.success.return_value = True
        resource.file = BytesIO(payload)
        resource.file_name = "note.txt"
        api_client = MagicMock()
        api_client.im.v1.message_resource.get.return_value = resource
        channel._api_client = api_client

        # Provider reports the local sandbox; no sandbox object is configured.
        provider = MagicMock()
        provider.acquire.return_value = "local"

        fake_paths = SimpleNamespace(
            ensure_thread_dirs=MagicMock(),
            sandbox_uploads_dir=MagicMock(return_value=tmp_path),
        )
        monkeypatch.setattr("app.channels.feishu.get_paths", lambda: fake_paths)
        monkeypatch.setattr("app.channels.feishu.get_sandbox_provider", lambda: provider)

        result = await channel._receive_single_file("message-1", "file-key", "file", "thread-local")

        assert result == "/mnt/user-data/uploads/note.txt"
        assert (tmp_path / "note.txt").read_bytes() == payload
        provider.get.assert_not_called()
        provider.release.assert_not_called()

    _run(scenario())


def test_feishu_receive_single_file_releases_sandbox_when_sync_fails(tmp_path, monkeypatch):
    """If sandbox.update_file raises, the failure message is returned and the sandbox is still released once."""

    async def scenario():
        channel = FeishuChannel(MessageBus(), {"app_id": "test", "app_secret": "test"})
        channel._GetMessageResourceRequest = _FakeGetMessageResourceRequest

        # Fake a successful resource download from the Feishu API client.
        resource = MagicMock()
        resource.success.return_value = True
        resource.file = BytesIO(b"hello uploads")
        resource.file_name = "note.txt"
        api_client = MagicMock()
        api_client.im.v1.message_resource.get.return_value = resource
        channel._api_client = api_client

        # The sandbox double blows up on sync to exercise the error path.
        sandbox = MagicMock()
        sandbox.update_file.side_effect = RuntimeError("sync failed")
        provider = MagicMock()
        provider.acquire.return_value = "aio-1"
        provider.get.return_value = sandbox

        fake_paths = SimpleNamespace(
            ensure_thread_dirs=MagicMock(),
            sandbox_uploads_dir=MagicMock(return_value=tmp_path),
        )
        monkeypatch.setattr("app.channels.feishu.get_paths", lambda: fake_paths)
        monkeypatch.setattr("app.channels.feishu.get_sandbox_provider", lambda: provider)

        result = await channel._receive_single_file("message-1", "file-key", "file", "thread-a")

        assert result == "Failed to obtain the [file]"
        provider.release.assert_called_once_with("aio-1")

    _run(scenario())


def test_feishu_on_message_extracts_image_and_file_keys():
bus = MessageBus()
channel = FeishuChannel(bus, {"app_id": "test", "app_secret": "test"})
Expand Down
1 change: 1 addition & 0 deletions backend/tests/test_sandbox_orphan_reconciliation.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ def _make_provider_for_reconciliation():
provider._thread_sandboxes = {}
provider._thread_locks = {}
provider._last_activity = {}
provider._lease_counts = {}
provider._warm_pool = {}
provider._shutdown_called = False
provider._idle_checker_stop = threading.Event()
Expand Down
Loading
Loading