Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 67 additions & 18 deletions ami/main/admin.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
from typing import Any

import pydantic
from django.contrib import admin
from django.db import models
from django.db.models.query import QuerySet
from django.http.request import HttpRequest
from django.http.response import HttpResponse
from django.template.defaultfilters import filesizeformat
from django.template.response import TemplateResponse
from django.urls import reverse
from django.utils.formats import number_format
from guardian.admin import GuardedModelAdmin

import ami.utils
from ami import tasks
from ami.jobs.models import Job
from ami.ml.models.project_pipeline_config import ProjectPipelineConfig
from ami.ml.post_processing.admin.small_size_filter_form import SmallSizeFilterActionForm
from ami.ml.post_processing.small_size_filter import SmallSizeFilterConfig
from ami.ml.tasks import remove_duplicate_classifications

from .models import (
Expand Down Expand Up @@ -652,24 +658,67 @@ def populate_collection_async(self, request: HttpRequest, queryset: QuerySet[Sou
)

@admin.action(description="Run Small Size Filter post-processing task (async)")
def run_small_size_filter(
    self, request: HttpRequest, queryset: QuerySet[SourceImageCollection]
) -> HttpResponse | None:
    """Admin action: collect parameters, then queue one Small Size Filter job per capture set.

    The action runs in two steps, distinguished by the ``confirm`` POST value:

    1. Without ``confirm`` the confirmation page is rendered with an unbound
       ``SmallSizeFilterActionForm`` so the admin can adjust the parameters.
    2. With ``confirm`` the form is bound to ``request.POST``. An invalid form
       re-renders the confirmation page with its errors. A valid form is turned
       into a config dict which, together with each collection's primary key,
       is validated through ``SmallSizeFilterConfig`` before a ``Job`` is
       created and enqueued.

    Args:
        request: The admin request that triggered the action.
        queryset: The capture sets selected in the changelist.

    Returns:
        The confirmation page while parameters are missing or invalid,
        otherwise ``None`` once the jobs have been queued (Django's admin then
        falls back to its default post-action response).
    """
    # Step 1: nothing confirmed yet -> show the parameter form.
    if not request.POST.get("confirm"):
        return self._render_small_size_filter_confirmation(request, queryset, SmallSizeFilterActionForm())

    # Step 2: parameters were submitted; re-render with errors when they are invalid.
    form = SmallSizeFilterActionForm(request.POST)
    if not form.is_valid():
        return self._render_small_size_filter_confirmation(request, queryset, form)

    cfg = form.to_config()
    jobs = []
    for collection in queryset:
        # Validate against the task schema before enqueueing so a bad payload
        # is reported here rather than surfacing later inside the job.
        try:
            validated = SmallSizeFilterConfig(
                **cfg,
                source_image_collection_id=collection.pk,
            )
        except pydantic.ValidationError as exc:
            self.message_user(
                request,
                f"Bad config for capture set {collection.pk}: {exc}",
                level="error",
            )
            continue

        job = Job.objects.create(
            name=f"Post-processing: SmallSizeFilter on Capture Set {collection.pk}",
            project=collection.project,
            job_type_key="post_processing",
            params={"task": "small_size_filter", "config": validated.dict()},
        )
        job.enqueue()
        jobs.append(job.pk)

    # Report the number of jobs actually created, not the size of the selection.
    self.message_user(request, f"Queued Small Size Filter for {len(jobs)} capture set(s). Jobs: {jobs}")
    return None

def _render_small_size_filter_confirmation(
    self,
    request: HttpRequest,
    queryset: QuerySet[SourceImageCollection],
    form: SmallSizeFilterActionForm,
) -> TemplateResponse:
    """Render the confirmation page for the Small Size Filter admin action.

    Args:
        request: The current admin request.
        queryset: The capture sets selected for the action.
        form: The parameter form to display; unbound on first display, or
            bound (carrying validation errors) when re-rendered.

    Returns:
        A ``TemplateResponse`` for ``admin/post_processing/confirmation.html``
        whose context combines the admin site's common context with the form,
        the selection (count and primary keys as strings) and the labels/URLs
        used by the template.
    """
    # Fetch the primary keys once and derive the count from that list: this
    # avoids a separate COUNT query and keeps both values consistent.
    selected_pks = [str(pk) for pk in queryset.values_list("pk", flat=True)]
    model_options = self.model._meta

    context = {
        **self.admin_site.each_context(request),
        "title": "Run Small Size Filter",
        "task_label": "Small Size Filter",
        "form": form,
        "selected_count": len(selected_pks),
        "selected_pks": selected_pks,
        "action_name": "run_small_size_filter",
        "submit_label": "Run Small Size Filter",
        "changelist_url": reverse("admin:main_sourceimagecollection_changelist"),
        # Both keys are supplied, as in the original context.
        # NOTE(review): presumably the template or admin base templates read one of each - confirm.
        "model_meta": model_options,
        "opts": model_options,
        "action_checkbox_name": admin.helpers.ACTION_CHECKBOX_NAME,
    }
    return TemplateResponse(request, "admin/post_processing/confirmation.html", context)

actions = [
populate_collection,
Expand Down
Empty file.
25 changes: 25 additions & 0 deletions ami/ml/post_processing/admin/forms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""Form base for admin actions that trigger post-processing tasks.

Each post-processing task surfaces its tunable knobs as a Django form. The
form's ``cleaned_data`` becomes the ``config`` payload on the resulting Job
(after validation against the task's pydantic ``config_schema``).

Algorithm scope (which queryset/events/collection the action runs against)
lives outside the form because it varies per admin entry-point.
"""
from __future__ import annotations

from django import forms


class BasePostProcessingActionForm(forms.Form):
    """Common base class for admin-action forms that configure a post-processing task.

    Concrete forms add the task-specific fields. ``to_config()`` may be
    overridden when the validated form data should not be passed through
    unchanged (for example to rename keys, drop empty optional values or add
    derived ones).
    """

    def to_config(self) -> dict:
        """Return a shallow copy of ``cleaned_data`` for use as ``Job.params['config']``."""
        return {**self.cleaned_data}
28 changes: 28 additions & 0 deletions ami/ml/post_processing/admin/small_size_filter_form.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from __future__ import annotations

from django import forms

from ami.ml.post_processing.admin.forms import BasePostProcessingActionForm
from ami.ml.post_processing.small_size_filter import SmallSizeFilterConfig


class SmallSizeFilterActionForm(BasePostProcessingActionForm):
    """Knobs surfaced when an admin triggers Small Size Filter.

    Exposes a single ``size_threshold`` field. Its initial value is read from
    ``SmallSizeFilterConfig`` so the form starts from the schema default.
    """

    size_threshold = forms.FloatField(
        label="Size threshold",
        # Read the default from the config schema rather than repeating the number here.
        initial=SmallSizeFilterConfig.__fields__["size_threshold"].default,
        min_value=0.0,
        max_value=1.0,
        help_text=(
            "Minimum bounding-box area as a fraction of the source image area "
            "(width × height). Detections smaller than this are flagged as "
            "'Not identifiable'. Default 0.0008 ≈ 0.08% of frame area."
        ),
    )

    def clean_size_threshold(self) -> float:
        """Reject thresholds outside the open interval (0, 1).

        ``min_value``/``max_value`` on the field still admit the endpoints 0.0
        and 1.0, so this check is what excludes them.

        Returns:
            The validated threshold.

        Raises:
            forms.ValidationError: If the value is not strictly between 0 and 1.
        """
        v = self.cleaned_data["size_threshold"]
        if not (0.0 < v < 1.0):
            raise forms.ValidationError("size_threshold must be in (0, 1) exclusive.")
        return v
14 changes: 11 additions & 3 deletions ami/ml/post_processing/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import typing
from typing import Any, Optional

import pydantic

from ami.ml.models import Algorithm
from ami.ml.models.algorithm import AlgorithmTaskType

Expand All @@ -13,15 +15,21 @@
class BasePostProcessingTask(abc.ABC):
"""
Abstract base class for all post-processing tasks.

Subclasses must declare a Pydantic ``config_schema`` describing the shape of
``Job.params['config']``. Config is validated at task construction so bad
payloads fail fast in worker logs (and earlier still — admin triggers and
other callers should validate via the same schema before enqueueing a Job).
"""

# Each task must override these
key: str
name: str
config_schema: type[pydantic.BaseModel]

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
required_attrs = ["key", "name"]
required_attrs = ["key", "name", "config_schema"]
for attr in required_attrs:
if not hasattr(cls, attr) or getattr(cls, attr) is None:
raise TypeError(f"{cls.__name__} must define '{attr}' class attribute")
Expand All @@ -33,7 +41,7 @@ def __init__(
**config: Any,
):
self.job = job
self.config = config
self.config: pydantic.BaseModel = self.config_schema(**config)
# Choose the right logger
if logger is not None:
self.logger = logger
Expand All @@ -52,7 +60,7 @@ def __init__(
)
self.algorithm: Algorithm = algorithm

self.logger.info(f"Initialized {self.name } with config={self.config}, job={job}")
self.logger.info(f"Initialized {self.name} with config={self.config.dict()}, job={job}")

def update_progress(self, progress: float):
"""
Expand Down
27 changes: 19 additions & 8 deletions ami/ml/post_processing/small_size_filter.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,34 @@
import pydantic
from django.utils import timezone

from ami.main.models import Classification, Detection, Occurrence, SourceImageCollection, Taxon, TaxonRank
from ami.ml.post_processing.base import BasePostProcessingTask
from ami.ml.schemas import BoundingBox


class SmallSizeFilterConfig(pydantic.BaseModel):
    """Validated configuration for the small size filter task.

    ``source_image_collection_id`` is required; ``size_threshold`` defaults to
    0.0008 and must lie strictly between 0 and 1. Unknown keys are rejected.
    """

    source_image_collection_id: int
    size_threshold: float = 0.0008

    class Config:
        # Fail on unexpected keys instead of silently ignoring them.
        extra = "forbid"

    @pydantic.validator("size_threshold")
    def _threshold_in_unit_interval(cls, value: float) -> float:
        """Accept only thresholds inside the open interval (0, 1)."""
        # The chained comparison is also False for NaN, so NaN is rejected too.
        if 0.0 < value < 1.0:
            return value
        raise ValueError("size_threshold must be in (0, 1) exclusive")


class SmallSizeFilterTask(BasePostProcessingTask):
key = "small_size_filter"
name = "Small size filter"
config_schema = SmallSizeFilterConfig

def run(self) -> None:
# Could we use a pydantic model for config validation if it's just for this task?
threshold = self.config.get("size_threshold", 0.0008)
collection_id = self.config.get("source_image_collection_id")
config: SmallSizeFilterConfig = self.config # type: ignore[assignment]
threshold = config.size_threshold
collection_id = config.source_image_collection_id

# Get or create the "Not identifiable" taxon
not_identifiable_taxon, _ = Taxon.objects.get_or_create(
Expand All @@ -24,11 +40,6 @@ def run(self) -> None:
)
self.logger.info(f"=== Starting {self.name} ===")

if not collection_id:
msg = "Missing required config param: source_image_collection_id"
self.logger.error(msg)
raise ValueError(msg)

try:
collection = SourceImageCollection.objects.get(pk=collection_id)
self.logger.info(f"Loaded SourceImageCollection {collection_id} (Project={collection.project})")
Expand Down
Empty file.
46 changes: 46 additions & 0 deletions ami/ml/post_processing/tests/test_admin_form.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""Tests for ``BasePostProcessingActionForm`` + concrete ``SmallSizeFilterActionForm``."""
from django.test import TestCase

from ami.ml.post_processing.admin.forms import BasePostProcessingActionForm
from ami.ml.post_processing.admin.small_size_filter_form import SmallSizeFilterActionForm


class _OneFieldForm(BasePostProcessingActionForm):
    """Minimal one-field form used to exercise ``BasePostProcessingActionForm.to_config()``."""

    # NOTE(review): this import runs inside the class body, so ``forms`` is
    # also bound as a class attribute; consider moving it to module level.
    from django import forms

    # Single float field; the tests submit a value for it explicitly.
    threshold = forms.FloatField(initial=0.5)


class TestBasePostProcessingActionForm(TestCase):
    """Checks the default ``to_config()`` pass-through of the base form."""

    def test_to_config_returns_cleaned_data(self):
        """A valid bound form yields its cleaned values as the config dict."""
        bound_form = _OneFieldForm(data={"threshold": "0.25"})
        form_is_valid = bound_form.is_valid()
        self.assertTrue(form_is_valid)

        expected_config = {"threshold": 0.25}
        self.assertEqual(bound_form.to_config(), expected_config)


class TestSmallSizeFilterActionForm(TestCase):
    """Validation behaviour of the ``size_threshold`` field on the admin form."""

    def _assert_threshold_rejected(self, raw_value: str) -> None:
        """Bind the form to ``raw_value`` and expect an error on ``size_threshold``."""
        bound_form = SmallSizeFilterActionForm(data={"size_threshold": raw_value})
        self.assertFalse(bound_form.is_valid())
        self.assertIn("size_threshold", bound_form.errors)

    def test_default_initial_matches_config_default(self):
        unbound_form = SmallSizeFilterActionForm()
        threshold_field = unbound_form.fields["size_threshold"]
        self.assertEqual(threshold_field.initial, 0.0008)

    def test_valid_threshold_passes(self):
        bound_form = SmallSizeFilterActionForm(data={"size_threshold": "0.001"})
        self.assertTrue(bound_form.is_valid())
        expected_config = {"size_threshold": 0.001}
        self.assertEqual(bound_form.to_config(), expected_config)

    def test_threshold_above_one_rejected(self):
        self._assert_threshold_rejected("1.5")

    def test_threshold_zero_rejected(self):
        # The interval is open: Django's min_value=0.0 lets zero through, so
        # clean_size_threshold is the check that rejects it.
        self._assert_threshold_rejected("0.0")

    def test_threshold_at_one_rejected(self):
        self._assert_threshold_rejected("1.0")
45 changes: 45 additions & 0 deletions ami/ml/post_processing/tests/test_base_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""Tests for the pydantic ``config_schema`` contract on ``BasePostProcessingTask``."""
import pydantic
import pytest
from django.test import TestCase

from ami.ml.post_processing.base import BasePostProcessingTask
from ami.ml.post_processing.small_size_filter import SmallSizeFilterConfig, SmallSizeFilterTask


class TestConfigSchemaContract(TestCase):
    """Subclasses must declare ``config_schema``, and task construction validates config against it."""

    def test_subclass_without_config_schema_raises(self):
        # Defining the class is what triggers the check, so the definition
        # itself lives inside the ``raises`` context.
        with pytest.raises(TypeError, match="config_schema"):

            class Missing(BasePostProcessingTask):
                key = "missing"
                name = "Missing schema"

                def run(self) -> None:
                    pass

    def test_valid_config_builds_basemodel_instance(self):
        task_kwargs = {"source_image_collection_id": 1, "size_threshold": 0.001}
        task = SmallSizeFilterTask(**task_kwargs)
        self.assertIsInstance(task.config, SmallSizeFilterConfig)

        parsed: SmallSizeFilterConfig = task.config  # type: ignore[assignment]
        self.assertEqual(parsed.size_threshold, 0.001)
        self.assertEqual(parsed.source_image_collection_id, 1)

    def test_default_value_applies_when_omitted(self):
        task = SmallSizeFilterTask(source_image_collection_id=1)
        parsed: SmallSizeFilterConfig = task.config  # type: ignore[assignment]
        self.assertEqual(parsed.size_threshold, 0.0008)

    def test_invalid_config_raises_at_init(self):
        out_of_range = {"source_image_collection_id": 1, "size_threshold": 2.0}
        with pytest.raises(pydantic.ValidationError):
            SmallSizeFilterTask(**out_of_range)

    def test_missing_required_field_raises(self):
        without_collection = {"size_threshold": 0.001}
        with pytest.raises(pydantic.ValidationError):
            SmallSizeFilterTask(**without_collection)

    def test_unknown_keys_rejected(self):
        with_extra_key = {"source_image_collection_id": 1, "unknown_field": "oops"}
        with pytest.raises(pydantic.ValidationError):
            SmallSizeFilterTask(**with_extra_key)
Loading
Loading