Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 107 additions & 1 deletion backend/src/api/routes/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,12 @@ async def split_clip(
async def merge_clips(
task_id: str, request: Request, db: AsyncSession = Depends(get_db)
):
"""Merge multiple clips into one clip."""
"""Merge clips synchronously. Kept for back-compat — prefer /merge_async.

The ffmpeg concat-encode regularly exceeds the ALB idle timeout for
multi-clip composites and surfaces as a 504 here. New callers should
use the async variant.
"""
try:
payload = await request.json()
clip_ids = payload.get("clip_ids") or []
Expand All @@ -678,6 +683,107 @@ async def merge_clips(
raise HTTPException(status_code=500, detail=f"Error merging clips: {str(e)}")


@router.post("/{task_id}/clips/merge_async", status_code=202)
async def merge_clips_async(
    task_id: str, request: Request, db: AsyncSession = Depends(get_db)
):
    """Enqueue a merge job and return immediately.

    Poll status via GET /tasks/{task_id}/clips/merge_jobs/{merge_job_id}.
    Validation (body shape, ownership, clip existence) runs synchronously so
    bad requests fail fast instead of burning a worker slot.

    Args:
        task_id: Task whose clips should be merged (path parameter).
        request: Incoming request; its JSON body must be an object carrying
            a ``clip_ids`` array with at least two entries.
        db: Database session injected by FastAPI.

    Returns:
        ``{"merge_job_id": <id>, "status": "queued"}`` (sent with HTTP 202).

    Raises:
        HTTPException: 400 when the body is not valid JSON, is not a JSON
            object, ``clip_ids`` is not an array, or fewer than two clips
            are given; 404 when a clip is missing or belongs to another
            task; 500 for any unexpected failure. HTTPExceptions raised by
            ``_require_task_owner`` propagate unchanged.
    """
    try:
        # A body that fails to parse is a client error, not a server fault.
        try:
            payload = await request.json()
        except ValueError as exc:
            raise HTTPException(
                status_code=400, detail="Request body must be valid JSON"
            ) from exc

        # Valid JSON may still be a list/string/number, which has no .get().
        if not isinstance(payload, dict):
            raise HTTPException(
                status_code=400, detail="Request body must be a JSON object"
            )

        clip_ids = payload.get("clip_ids") or []
        if not isinstance(clip_ids, list):
            raise HTTPException(status_code=400, detail="clip_ids must be an array")
        if len(clip_ids) < 2:
            raise HTTPException(
                status_code=400, detail="At least two clips are required to merge"
            )

        task_service = TaskService(db)
        # Presumably raises an HTTPException when the caller does not own
        # the task -- confirm against _require_task_owner.
        await _require_task_owner(request, task_service, db, task_id)

        # Reject unknown or foreign clips before touching the queue.
        for clip_id in clip_ids:
            clip = await task_service.clip_repo.get_clip_by_id(db, clip_id)
            if not clip or clip["task_id"] != task_id:
                raise HTTPException(
                    status_code=404, detail=f"Clip {clip_id} not found on task"
                )

        merge_job_id = await JobQueue.enqueue_job(
            "merge_clips_job", task_id, clip_ids
        )
        logger.info(
            f"Enqueued merge job {merge_job_id} task={task_id} clips={len(clip_ids)}"
        )
        return {"merge_job_id": merge_job_id, "status": "queued"}
    except HTTPException:
        # Deliberate client-facing errors pass through untouched.
        raise
    except Exception as e:
        logger.error(f"Error enqueueing merge: {e}")
        raise HTTPException(
            status_code=500, detail=f"Error enqueueing merge: {str(e)}"
        ) from e


@router.get("/{task_id}/clips/merge_jobs/{merge_job_id}")
async def get_merge_job(
    task_id: str,
    merge_job_id: str,
    request: Request,
    db: AsyncSession = Depends(get_db),
):
    """Report the current state of a previously enqueued merge job.

    The ``status`` field is the lower-cased last dotted segment of whatever
    ``JobQueue.get_job_status`` returns (documented upstream as mirroring
    arq's JobStatus: deferred | queued | in_progress | complete |
    not_found). When the status is ``complete`` the response additionally
    carries ``clip_id`` and ``message`` taken from the worker's dict
    result, or ``error`` holding the str() of the exception raised while
    fetching the result (or a note about an unexpected result type).

    Raises:
        HTTPException: 404 when the status lookup returns ``None``; 500 for
            any unexpected failure. HTTPExceptions from the ownership check
            propagate unchanged.
    """
    # NOTE(review): nothing visible here ties merge_job_id to task_id, so an
    # owner of any task could presumably poll another task's job -- confirm
    # whether JobQueue enforces this.
    try:
        await _require_task_owner(request, TaskService(db), db, task_id)

        job_status = await JobQueue.get_job_status(merge_job_id)
        if job_status is None:
            raise HTTPException(
                status_code=404, detail=f"Merge job {merge_job_id} not found"
            )

        # "JobStatus.complete" -> "complete"; a bare "complete" stays as is.
        state = str(job_status).rpartition(".")[2].lower()
        body: Dict[str, Any] = {"merge_job_id": merge_job_id, "status": state}

        if state != "complete":
            return body

        try:
            outcome = await JobQueue.get_job_result(merge_job_id)
            if not isinstance(outcome, dict):
                body["error"] = (
                    f"Unexpected worker result type: {type(outcome).__name__}"
                )
            else:
                body["clip_id"] = outcome.get("clip_id")
                body["message"] = outcome.get("message")
        except Exception as exc:
            # Fetching the result of a failed job re-raises the worker's
            # exception; hand its string form back to the caller.
            body["error"] = str(exc)

        return body
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error fetching merge job status: {e}")
        raise HTTPException(
            status_code=500, detail=f"Error fetching merge job status: {str(e)}"
        )


@router.patch("/{task_id}/clips/{clip_id}/captions")
async def update_clip_captions(
task_id: str, clip_id: str, request: Request, db: AsyncSession = Depends(get_db)
Expand Down
38 changes: 37 additions & 1 deletion backend/src/workers/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,42 @@ async def clip_ready_callback(
# Error will be caught by arq and task status will be updated
raise

async def merge_clips_job(
    ctx: Dict[str, Any],
    task_id: str,
    clip_ids: list[str],
) -> Dict[str, Any]:
    """
    Worker-side entry point that merges clips outside the HTTP request.

    Enqueued by /tasks/{task_id}/clips/merge_async and polled through
    /tasks/{task_id}/clips/merge_jobs/{job_id}, so the caller does not keep
    a connection open while the merge runs (the synchronous
    /tasks/{task_id}/clips/merge route blocks for the whole encode and can
    outlast the load balancer's idle timeout).

    Args:
        ctx: Worker context passed in by the job runner; unused here.
        task_id: Task that owns the clips.
        clip_ids: Identifiers of the clips to merge.

    Returns:
        Whatever ``TaskService.merge_clips`` returns, passed through
        untouched so the job result reaches the poller as-is.
    """
    # Imported lazily inside the job, as in the original.
    from ..database import AsyncSessionLocal
    from ..runtime_settings import load_runtime_settings_cache
    from ..services.task_service import TaskService

    set_trace_id(f"merge-{task_id}")
    logger.info(f"Worker merging {len(clip_ids)} clips for task {task_id}")

    async with AsyncSessionLocal() as session:
        await load_runtime_settings_cache(session)
        merged = await TaskService(session).merge_clips(task_id, clip_ids)
        logger.info(
            f"Merge complete task={task_id} merged_clip_id={merged.get('clip_id')}"
        )
        return merged


# Worker configuration for arq
class WorkerSettings:
"""Configuration for arq worker."""
Expand All @@ -128,7 +164,7 @@ class WorkerSettings:
config = Config()

# Functions to run
functions = [process_video_task]
functions = [process_video_task, merge_clips_job]
queue_name = "supoclip_tasks"

# Redis settings from environment
Expand Down
152 changes: 151 additions & 1 deletion backend/tests/integration/test_health_and_tasks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
from unittest.mock import AsyncMock, patch

import pytest

from tests.fixtures.factories import create_source, create_task, create_user
from tests.fixtures.factories import (
create_clip,
create_source,
create_task,
create_user,
)


@pytest.mark.asyncio
Expand Down Expand Up @@ -99,3 +106,146 @@ async def test_upload_video_uses_runtime_config_temp_dir(
payload = response.json()
saved_name = payload["video_path"].removeprefix("upload://")
assert (tmp_path / "uploads" / saved_name).exists()


@pytest.mark.asyncio
async def test_merge_async_enqueues_and_returns_job_id(
    client, db_session, auth_headers
):
    """A valid two-clip request is enqueued once and answered with 202."""
    await create_user(db_session, user_id="user-1", email="owner@example.com")
    source = await create_source(db_session, title="Owner source")
    task = await create_task(db_session, user_id="user-1", source_id=source["id"])
    first_clip = await create_clip(db_session, task_id=task["id"])
    second_clip = await create_clip(db_session, task_id=task["id"])
    requested_ids = [first_clip["id"], second_clip["id"]]

    # Stub the queue so no real job is scheduled.
    enqueue_mock = AsyncMock(return_value="merge-job-xyz")
    with patch("src.api.routes.tasks.JobQueue.enqueue_job", new=enqueue_mock):
        response = await client.post(
            f"/tasks/{task['id']}/clips/merge_async",
            headers=auth_headers,
            json={"clip_ids": requested_ids},
        )

    assert response.status_code == 202
    assert response.json() == {"merge_job_id": "merge-job-xyz", "status": "queued"}
    enqueue_mock.assert_awaited_once_with(
        "merge_clips_job", task["id"], requested_ids
    )


@pytest.mark.asyncio
async def test_merge_async_rejects_unknown_clip(client, db_session, auth_headers):
    """An unknown clip id yields 404 and the queue is never touched."""
    await create_user(db_session, user_id="user-1", email="owner@example.com")
    source = await create_source(db_session, title="Owner source")
    task = await create_task(db_session, user_id="user-1", source_id=source["id"])
    real_clip = await create_clip(db_session, task_id=task["id"])

    # Validation must fail before anything is enqueued, so a typo in a clip
    # id never costs a worker slot.
    enqueue_mock = AsyncMock(return_value="should-not-be-called")
    with patch("src.api.routes.tasks.JobQueue.enqueue_job", new=enqueue_mock):
        response = await client.post(
            f"/tasks/{task['id']}/clips/merge_async",
            headers=auth_headers,
            json={"clip_ids": [real_clip["id"], "ghost-clip-id"]},
        )

    assert response.status_code == 404
    detail = response.json()["detail"]
    assert "ghost-clip-id" in detail
    enqueue_mock.assert_not_awaited()


@pytest.mark.asyncio
async def test_merge_async_rejects_single_clip(client, db_session, auth_headers):
    """Submitting only one clip id is rejected with a 400."""
    await create_user(db_session, user_id="user-1", email="owner@example.com")
    source = await create_source(db_session, title="Owner source")
    task = await create_task(db_session, user_id="user-1", source_id=source["id"])
    only_clip = await create_clip(db_session, task_id=task["id"])

    merge_url = f"/tasks/{task['id']}/clips/merge_async"
    response = await client.post(
        merge_url,
        headers=auth_headers,
        json={"clip_ids": [only_clip["id"]]},
    )

    assert response.status_code == 400
    detail = response.json()["detail"]
    assert "two" in detail.lower()


@pytest.mark.asyncio
async def test_get_merge_job_returns_completion_result(
    client, db_session, auth_headers
):
    """A completed job exposes the worker's clip_id and message."""
    await create_user(db_session, user_id="user-1", email="owner@example.com")
    source = await create_source(db_session, title="Owner source")
    task = await create_task(db_session, user_id="user-1", source_id=source["id"])

    # Enum-style status string; the route keeps only the part after the dot.
    status_mock = AsyncMock(return_value="JobStatus.complete")
    result_mock = AsyncMock(return_value={"clip_id": "merged-1", "message": "ok"})

    with patch(
        "src.api.routes.tasks.JobQueue.get_job_status", new=status_mock
    ), patch("src.api.routes.tasks.JobQueue.get_job_result", new=result_mock):
        response = await client.get(
            f"/tasks/{task['id']}/clips/merge_jobs/job-abc",
            headers=auth_headers,
        )

    assert response.status_code == 200
    expected = {
        "merge_job_id": "job-abc",
        "status": "complete",
        "clip_id": "merged-1",
        "message": "ok",
    }
    assert response.json() == expected


@pytest.mark.asyncio
async def test_get_merge_job_surfaces_worker_error(client, db_session, auth_headers):
    """An exception raised while fetching the result is reported as `error`."""
    await create_user(db_session, user_id="user-1", email="owner@example.com")
    source = await create_source(db_session, title="Owner source")
    task = await create_task(db_session, user_id="user-1", source_id=source["id"])

    status_mock = AsyncMock(return_value="complete")
    # Fetching the result raises, standing in for a failed worker run.
    result_mock = AsyncMock(side_effect=RuntimeError("ffmpeg exit 254"))

    with patch(
        "src.api.routes.tasks.JobQueue.get_job_status", new=status_mock
    ), patch("src.api.routes.tasks.JobQueue.get_job_result", new=result_mock):
        response = await client.get(
            f"/tasks/{task['id']}/clips/merge_jobs/job-bad",
            headers=auth_headers,
        )

    assert response.status_code == 200
    body = response.json()
    assert body["status"] == "complete"
    assert "ffmpeg exit 254" in body["error"]


@pytest.mark.asyncio
async def test_get_merge_job_returns_404_when_unknown(
    client, db_session, auth_headers
):
    """A status lookup that returns None is answered with a 404."""
    await create_user(db_session, user_id="user-1", email="owner@example.com")
    source = await create_source(db_session, title="Owner source")
    task = await create_task(db_session, user_id="user-1", source_id=source["id"])

    # None from the status lookup is the route's "no such job" signal.
    status_mock = AsyncMock(return_value=None)
    with patch("src.api.routes.tasks.JobQueue.get_job_status", new=status_mock):
        response = await client.get(
            f"/tasks/{task['id']}/clips/merge_jobs/ghost",
            headers=auth_headers,
        )

    assert response.status_code == 404
Loading