Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions tests/test_modality_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,24 @@ def test_base_py_imports_the_registry_derived_sets():
assert base._FILE_BEARING_CATEGORIES is registry.FILE_BEARING_CATEGORIES
assert base._TABULAR_FAMILY_CATEGORIES is registry.TABULAR_FAMILY_CATEGORIES
assert base._SELF_SUPERVISED_CATEGORIES is registry.SELF_SUPERVISED_CATEGORIES


def test_data_format_valid_and_matches_conventions():
    """Check each spec's ``data_format`` and its agreement with conventions.

    For every category in ``REGISTRY`` the spec's ``data_format`` must be one
    of ``DataFormat.get_all_formats()`` (P3d), and
    ``conventions._data_format_for`` must return that same value for the
    category, i.e. the registry is the single source.
    """
    from tracebloc_ingestor.cli.conventions import _data_format_for
    from tracebloc_ingestor.utils.constants import DataFormat

    known_formats = frozenset(DataFormat.get_all_formats())
    for category, spec in REGISTRY.items():
        # First: the value stored on the spec is a recognised format.
        if spec.data_format not in known_formats:
            raise AssertionError(f"{category}: bad data_format {spec.data_format!r}")
        # Second: conventions reports exactly what the registry stores.
        resolved = _data_format_for(category)
        assert resolved == spec.data_format


def test_transfer_present_iff_file_bearing():
    """A spec has a sidecar transfer factory exactly when it is file-bearing.

    This is the P3c invariant: ``is_file_bearing`` and "``transfer`` is set"
    must agree for every entry in ``REGISTRY``.
    """
    for entry in REGISTRY.values():
        has_transfer = entry.transfer is not None
        assert entry.is_file_bearing == has_transfer
105 changes: 58 additions & 47 deletions tracebloc_ingestor/cli/conventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,42 +27,54 @@
from dataclasses import dataclass, field
from typing import Any, Dict, FrozenSet, List, Optional

from ..utils.constants import DataFormat, Intent, TaskCategory

from ..modalities.registry import spec_for
from ..utils.constants import TaskCategory

# ---------------------------------------------------------------------------
# Category groupings — used both here and by the entrypoint when deciding
# which sidecar paths matter. Single source of truth.
# ---------------------------------------------------------------------------

# Each grouping is defined exactly once. The scraped diff carried both the
# pre-format and post-format copy of every assignment; the values were
# identical, so only one copy is kept.

# Image task categories.
IMAGE_CATEGORIES: FrozenSet[str] = frozenset(
    {
        TaskCategory.IMAGE_CLASSIFICATION,
        TaskCategory.OBJECT_DETECTION,
        TaskCategory.KEYPOINT_DETECTION,
        TaskCategory.SEMANTIC_SEGMENTATION,
    }
)

# Text task categories (masked language modeling is kept in its own set below).
TEXT_CATEGORIES: FrozenSet[str] = frozenset(
    {
        TaskCategory.TEXT_CLASSIFICATION,
        TaskCategory.TOKEN_CLASSIFICATION,
    }
)

# Tabular task categories.
TABULAR_CATEGORIES: FrozenSet[str] = frozenset(
    {
        TaskCategory.TABULAR_CLASSIFICATION,
        TaskCategory.TABULAR_REGRESSION,
    }
)

# Time-series forecasting.
TIME_SERIES_CATEGORIES: FrozenSet[str] = frozenset(
    {
        TaskCategory.TIME_SERIES_FORECASTING,
    }
)

# Time-to-event prediction.
TIME_TO_EVENT_CATEGORIES: FrozenSet[str] = frozenset(
    {
        TaskCategory.TIME_TO_EVENT_PREDICTION,
    }
)

# Masked language modeling.
MLM_CATEGORIES: FrozenSet[str] = frozenset(
    {
        TaskCategory.MASKED_LANGUAGE_MODELING,
    }
)

# Categories where the label is a numeric prediction target rather than
# class metadata. The schema requires `label.policy` for these so the raw
Expand Down Expand Up @@ -101,12 +113,18 @@
# framework's own samples round-trip through the documented happy-path
# config with zero overrides. Production users with differently-sized data
# should set `spec.file_options` in their YAML.
TaskCategory.IMAGE_CLASSIFICATION: {"target_size": [256, 256], "extension": ".jpeg"},
TaskCategory.SEMANTIC_SEGMENTATION: {"target_size": [512, 512], "extension": ".jpg"},
TaskCategory.OBJECT_DETECTION: {"target_size": [1920, 1080], "extension": ".jpg"},
TaskCategory.IMAGE_CLASSIFICATION: {
"target_size": [256, 256],
"extension": ".jpeg",
},
TaskCategory.SEMANTIC_SEGMENTATION: {
"target_size": [512, 512],
"extension": ".jpg",
},
TaskCategory.OBJECT_DETECTION: {"target_size": [1920, 1080], "extension": ".jpg"},
# keypoint_detection: no target_size default — the customer's pose model
# dictates input resolution, so the schema requires it top-level.
TaskCategory.KEYPOINT_DETECTION: {"extension": ".jpg"},
TaskCategory.KEYPOINT_DETECTION: {"extension": ".jpg"},
}

DEFAULT_TEXT_FILE_OPTIONS: Dict[str, Any] = {
Expand All @@ -122,6 +140,7 @@
# Resolved configuration — what the entrypoint actually consumes.
# ---------------------------------------------------------------------------


@dataclass
class ResolvedConfig:
"""A fully-resolved ingest configuration.
Expand Down Expand Up @@ -183,6 +202,7 @@ class ResolvedConfig:
# Resolver
# ---------------------------------------------------------------------------


def resolve(config: Dict[str, Any]) -> ResolvedConfig:
"""Translate a validated ingest.yaml dict into a :class:`ResolvedConfig`.

Expand Down Expand Up @@ -302,25 +322,16 @@ def resolve(config: Dict[str, Any]) -> ResolvedConfig:
# Helpers
# ---------------------------------------------------------------------------


def _data_format_for(category: str) -> str:
    """Map ``category`` to the ``DataFormat`` value the framework expects.

    Reads the single source of truth — the ModalityRegistry — rather than the
    per-format frozensets in this module (structural refactor backend#796,
    P3d). The former frozenset ladder is removed: left in place it returned or
    raised on every path, which made the registry lookup unreachable.

    Args:
        category: the task category to look up.

    Returns:
        The ``data_format`` stored on the registry spec for ``category``.

    Raises:
        ValueError: propagated from ``spec_for`` for an unknown category,
            same as the previous ladder did.
    """
    return spec_for(category).data_format


def _default_file_options_for(category: str) -> Dict[str, Any]:
Expand Down
16 changes: 14 additions & 2 deletions tracebloc_ingestor/modalities/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,24 @@

from typing import Dict

from ..utils.constants import TaskCategory
from ..utils.constants import DataFormat, TaskCategory
from . import transfer as t
from . import validators as v
from .spec import ModalitySpec

# One entry per supported category — the single source of truth. P3a: the three
# behavior flags (was three frozensets in ingestors/base.py). P3b: the
# validator factory (was the map_validators if/elif arm). P3c: the sidecar
# transfer factory (was the map_file_transfer if/elif arm). P3d: data_format
# (was the 6-frozenset ladder in conventions._data_format_for).
_SPECS = (
# File-bearing categories (per-row sidecar files under SRC_PATH).
ModalitySpec(
TaskCategory.IMAGE_CLASSIFICATION,
is_file_bearing=True,
is_tabular_family=False,
is_self_supervised=False,
data_format=DataFormat.IMAGE,
build_validators=v.image_classification,
transfer=t.image_classification,
),
Expand All @@ -36,6 +38,7 @@
is_file_bearing=True,
is_tabular_family=False,
is_self_supervised=False,
data_format=DataFormat.IMAGE,
build_validators=v.object_detection,
transfer=t.object_detection,
),
Expand All @@ -44,6 +47,7 @@
is_file_bearing=True,
is_tabular_family=False,
is_self_supervised=False,
data_format=DataFormat.IMAGE,
build_validators=v.keypoint_detection,
transfer=t.keypoint_detection,
),
Expand All @@ -52,6 +56,7 @@
is_file_bearing=True,
is_tabular_family=False,
is_self_supervised=False,
data_format=DataFormat.IMAGE,
build_validators=v.semantic_segmentation,
transfer=t.semantic_segmentation,
),
Expand All @@ -60,6 +65,7 @@
is_file_bearing=True,
is_tabular_family=False,
is_self_supervised=False,
data_format=DataFormat.TEXT,
build_validators=v.text_classification,
transfer=t.text_classification,
),
Expand All @@ -68,6 +74,7 @@
is_file_bearing=True,
is_tabular_family=False,
is_self_supervised=False,
data_format=DataFormat.TEXT,
build_validators=v.token_classification,
transfer=t.token_classification,
),
Expand All @@ -77,6 +84,7 @@
is_file_bearing=True,
is_tabular_family=False,
is_self_supervised=True,
data_format=DataFormat.TEXT,
build_validators=v.masked_language_modeling,
transfer=t.masked_language_modeling,
),
Expand All @@ -86,27 +94,31 @@
is_file_bearing=False,
is_tabular_family=True,
is_self_supervised=False,
data_format=DataFormat.TABULAR,
build_validators=v.tabular_classification,
),
ModalitySpec(
TaskCategory.TABULAR_REGRESSION,
is_file_bearing=False,
is_tabular_family=True,
is_self_supervised=False,
data_format=DataFormat.TABULAR,
build_validators=v.tabular_regression,
),
ModalitySpec(
TaskCategory.TIME_SERIES_FORECASTING,
is_file_bearing=False,
is_tabular_family=True,
is_self_supervised=False,
data_format=DataFormat.TABULAR,
build_validators=v.time_series_forecasting,
),
ModalitySpec(
TaskCategory.TIME_TO_EVENT_PREDICTION,
is_file_bearing=False,
is_tabular_family=True,
is_self_supervised=False,
data_format=DataFormat.TABULAR,
build_validators=v.time_to_event_prediction,
),
)
Expand Down
4 changes: 4 additions & 0 deletions tracebloc_ingestor/modalities/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ class ModalitySpec:

Attributes:
category: the ``TaskCategory`` value this spec describes.
data_format: the ``DataFormat`` the framework expects for this category
(P3d). Read by ``conventions._data_format_for`` (was a 6-frozenset
ladder there).
build_validators: ``(file_options) -> [validators]`` — the validator
set this category runs (P3b). Replaces the corresponding
``map_validators`` if/elif arm; the factory bodies live in
Expand Down Expand Up @@ -61,6 +64,7 @@ class ModalitySpec:
is_file_bearing: bool
is_tabular_family: bool
is_self_supervised: bool
data_format: str
build_validators: Callable[[Dict[str, Any]], List]
transfer: Optional[
Callable[[Dict[str, Any], Dict[str, Any]], Optional[Dict[str, Any]]]
Expand Down
Loading