Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions tests/embedder/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations

import importlib.util
import platform
from typing import Any

import pytest
import torch
Expand Down Expand Up @@ -77,9 +80,9 @@ def on_windows() -> bool:
]


def create_sentence_transformer_config(**kwargs) -> SentenceTransformerEmbeddingConfig:
def create_sentence_transformer_config(**kwargs: Any) -> SentenceTransformerEmbeddingConfig:
"""Helper function to create SentenceTransformer config with defaults."""
defaults = {
defaults: dict[str, Any] = {
"model_name": "sergeyzh/rubert-tiny-turbo",
"batch_size": 4,
"device": "cpu",
Expand All @@ -90,9 +93,9 @@ def create_sentence_transformer_config(**kwargs) -> SentenceTransformerEmbedding
return SentenceTransformerEmbeddingConfig(**defaults)


def create_openai_config(**kwargs) -> OpenaiEmbeddingConfig:
def create_openai_config(**kwargs: Any) -> OpenaiEmbeddingConfig:
"""Helper function to create OpenAI config with defaults."""
defaults = {
defaults: dict[str, Any] = {
"model_name": "text-embedding-3-small",
"batch_size": 2,
"use_cache": False,
Expand All @@ -103,9 +106,9 @@ def create_openai_config(**kwargs) -> OpenaiEmbeddingConfig:
return OpenaiEmbeddingConfig(**defaults)


def create_vllm_config(**kwargs) -> VllmEmbeddingConfig:
def create_vllm_config(**kwargs: Any) -> VllmEmbeddingConfig:
"""Helper function to create VllmEmbeddingConfig with test-friendly defaults."""
defaults = {
defaults: dict[str, Any] = {
"model_name": "BAAI/bge-base-en-v1.5",
"batch_size": 4,
"use_cache": False,
Expand All @@ -117,5 +120,5 @@ def create_vllm_config(**kwargs) -> VllmEmbeddingConfig:


@pytest.fixture(autouse=True)
def _autouse_fake_openai_embedding(patch_openai_embedding_backend):
def _autouse_fake_openai_embedding(patch_openai_embedding_backend: None) -> None:
"""Within tests/embedder/, every OpenaiEmbeddingConfig resolves to FakeOpenaiEmbeddingBackend."""
10 changes: 5 additions & 5 deletions tests/embedder/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def embedder(self, embedder_config: EmbedderConfig) -> Embedder:
"""Create an Embedder instance for testing."""
return Embedder(embedder_config)

def test_embedding_calculation(self, embedder: Embedder):
def test_embedding_calculation(self, embedder: Embedder) -> None:
"""Test basic embedding calculation functionality."""
test_utterances = ["Hello world", "Test sentence", "Another example"]

Expand All @@ -34,7 +34,7 @@ def test_embedding_calculation(self, embedder: Embedder):
if hasattr(embedder.config, "similarity_fn_name"):
assert np.allclose(np.linalg.norm(embeddings, axis=1), 1.0, atol=1e-5) # normalized

def test_embedding_reproducibility(self, embedder: Embedder):
def test_embedding_reproducibility(self, embedder: Embedder) -> None:
"""Test that embeddings are reproducible for same input."""
test_utterances = ["Hello world", "Test sentence"]

Expand All @@ -43,13 +43,13 @@ def test_embedding_reproducibility(self, embedder: Embedder):

np.testing.assert_allclose(embeddings1, embeddings2, rtol=1e-5)

def test_single_utterance(self, embedder: Embedder):
def test_single_utterance(self, embedder: Embedder) -> None:
"""Test embedding calculation for single utterance."""
embeddings = embedder.embed(["Single test sentence"])
assert embeddings.shape[0] == 1
assert embeddings.shape[1] > 0

def test_similarity_calculation(self, embedder: Embedder):
def test_similarity_calculation(self, embedder: Embedder) -> None:
"""Test similarity calculation between embeddings."""
utterances = ["Hello world", "Test sentence", "Another test"]
embeddings = embedder.embed(utterances)
Expand All @@ -62,7 +62,7 @@ def test_similarity_calculation(self, embedder: Embedder):
assert np.all(sim_matrix >= -1.0)
assert np.all(sim_matrix <= 1.0)

def test_similarity_symmetry(self, embedder: Embedder):
def test_similarity_symmetry(self, embedder: Embedder) -> None:
"""Test that similarity is symmetric."""
utterances = ["Hello world", "Test sentence"]
embeddings = embedder.embed(utterances)
Expand Down
10 changes: 5 additions & 5 deletions tests/embedder/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
class TestEmbedderCaching:
"""Test caching functionality for different embedder backends."""

def test_caching_consistency(self, embedder_config: EmbedderConfig):
def test_caching_consistency(self, embedder_config: EmbedderConfig) -> None:
"""Test that caching produces consistent results when enabled."""
# Create config with caching enabled
if hasattr(embedder_config, "model_copy"):
Expand All @@ -40,7 +40,7 @@ def test_caching_consistency(self, embedder_config: EmbedderConfig):
# Verify results are identical
np.testing.assert_allclose(embeddings1, embeddings2, rtol=1e-5)

def test_caching_disabled_consistency(self, embedder_config: EmbedderConfig):
def test_caching_disabled_consistency(self, embedder_config: EmbedderConfig) -> None:
"""Test behavior when caching is disabled."""
# Ensure caching is disabled
if hasattr(embedder_config, "model_copy"):
Expand All @@ -63,7 +63,7 @@ def test_caching_disabled_consistency(self, embedder_config: EmbedderConfig):
class TestSentenceTransformerCachingSpecific:
"""Test caching functionality specific to SentenceTransformer backend."""

def test_caching_performance_improvement(self):
def test_caching_performance_improvement(self) -> None:
"""Test that caching provides performance improvement."""
config = create_sentence_transformer_config(use_cache=True)
embedder = Embedder(config)
Expand All @@ -83,7 +83,7 @@ def test_caching_performance_improvement(self):
# but we can at least verify the caching mechanism works
assert embeddings1.shape == embeddings2.shape

def test_different_inputs_no_cache_collision(self):
def test_different_inputs_no_cache_collision(self) -> None:
"""Test that different inputs don't collide in cache."""
config = create_sentence_transformer_config(use_cache=True)
embedder = Embedder(config)
Expand All @@ -94,7 +94,7 @@ def test_different_inputs_no_cache_collision(self):
# Different inputs should produce different embeddings
assert not np.allclose(embeddings1, embeddings2, rtol=1e-3)

def test_cache_with_different_prompts(self):
def test_cache_with_different_prompts(self) -> None:
"""Test that prompts are considered in caching."""
config = create_sentence_transformer_config(
use_cache=True,
Expand Down
54 changes: 39 additions & 15 deletions tests/embedder/test_dump_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@

import tempfile
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast

import numpy as np
import pytest

from autointent._wrappers.embedder import Embedder
from autointent.configs import SentenceTransformerEmbeddingConfig
from autointent.configs import (
OpenaiEmbeddingConfig,
SentenceTransformerEmbeddingConfig,
)
from tests.conftest import tiny_sentence_transformer

from .conftest import backend_configs
Expand All @@ -17,7 +20,7 @@
from autointent.configs import EmbedderConfig


def test_load_from_disk(on_windows):
def test_load_from_disk(on_windows: bool) -> None:
"""Test loading embedder from disk with custom saved model."""
model = tiny_sentence_transformer()

Expand All @@ -41,7 +44,12 @@ def embedder(self, embedder_config: EmbedderConfig) -> Embedder:
"""Create an Embedder instance for testing."""
return Embedder(embedder_config)

def test_dump_load_cycle(self, embedder: Embedder, on_windows, embedder_config: EmbedderConfig): # noqa: ARG002
def test_dump_load_cycle(
self,
embedder: Embedder,
on_windows: bool,
embedder_config: EmbedderConfig, # noqa: ARG002
) -> None:
"""Test complete dump/load cycle preserves functionality."""
with tempfile.TemporaryDirectory(ignore_cleanup_errors=on_windows) as temp_dir:
temp_path = Path(temp_dir)
Expand All @@ -60,17 +68,28 @@ def test_dump_load_cycle(self, embedder: Embedder, on_windows, embedder_config:
loaded_embeddings = embedder_loaded.embed(test_utterances)
np.testing.assert_allclose(original_embeddings, loaded_embeddings, rtol=1e-3)

# Test configuration preservation (only for configs that have these attributes)
# Test configuration preservation (only for configs that have these attributes).
# The BaseEmbedderConfig union doesn't expose backend-specific fields; the hasattr
# checks are runtime guards, so cast to a concrete subclass with the attribute.
if hasattr(embedder.config, "model_name"):
assert embedder_loaded.config.model_name == embedder.config.model_name
loaded_named = cast("SentenceTransformerEmbeddingConfig", embedder_loaded.config)
original_named = cast("SentenceTransformerEmbeddingConfig", embedder.config)
assert loaded_named.model_name == original_named.model_name
if hasattr(embedder.config, "default_prompt"):
assert embedder_loaded.config.default_prompt == embedder.config.default_prompt
if hasattr(embedder.config, "batch_size"):
assert embedder_loaded.config.batch_size == embedder.config.batch_size

def test_load_with_config_override(self, embedder: Embedder, on_windows, embedder_config: EmbedderConfig): # noqa: ARG002
loaded_batched = cast("OpenaiEmbeddingConfig", embedder_loaded.config)
original_batched = cast("OpenaiEmbeddingConfig", embedder.config)
assert loaded_batched.batch_size == original_batched.batch_size

def test_load_with_config_override(
self,
embedder: Embedder,
on_windows: bool,
embedder_config: EmbedderConfig, # noqa: ARG002
) -> None:
"""Test loading with configuration override."""
from autointent.configs import HashingVectorizerEmbeddingConfig, OpenaiEmbeddingConfig
from autointent.configs import HashingVectorizerEmbeddingConfig

# Skip for HashingVectorizer as it doesn't support batch_size override
if isinstance(embedder.config, HashingVectorizerEmbeddingConfig):
Expand All @@ -83,6 +102,7 @@ def test_load_with_config_override(self, embedder: Embedder, on_windows, embedde
embedder.dump(temp_path)

# Create appropriate override config based on backend type
override_config: EmbedderConfig
if isinstance(embedder.config, SentenceTransformerEmbeddingConfig):
override_config = SentenceTransformerEmbeddingConfig(batch_size=16)
else:
Expand All @@ -92,12 +112,16 @@ def test_load_with_config_override(self, embedder: Embedder, on_windows, embedde
# Load with override
embedder_loaded = Embedder.load(temp_path, override_config)

# Verify override took effect
assert embedder_loaded.config.batch_size == 16
# Verify override took effect. embedder_loaded.config is the union
# BaseEmbedderConfig | ...; both SentenceTransformer and Openai
# subclasses carry batch_size/model_name, so cast for attribute access.
loaded_specific = cast("OpenaiEmbeddingConfig", embedder_loaded.config)
original_specific = cast("OpenaiEmbeddingConfig", embedder.config)
assert loaded_specific.batch_size == 16
# Verify original config preserved where not overridden
assert embedder_loaded.config.model_name == embedder.config.model_name
assert loaded_specific.model_name == original_specific.model_name

def test_similarity_preserved_after_load(self, embedder: Embedder, on_windows):
def test_similarity_preserved_after_load(self, embedder: Embedder, on_windows: bool) -> None:
"""Test that similarity function works correctly after dump/load."""
with tempfile.TemporaryDirectory(ignore_cleanup_errors=on_windows) as temp_dir:
temp_path = Path(temp_dir)
Expand All @@ -118,7 +142,7 @@ def test_similarity_preserved_after_load(self, embedder: Embedder, on_windows):
# Similarities should be the same
np.testing.assert_allclose(original_similarity, loaded_similarity, rtol=1e-3)

def test_multiple_dump_load_cycles(self, embedder: Embedder, on_windows):
def test_multiple_dump_load_cycles(self, embedder: Embedder, on_windows: bool) -> None:
"""Test multiple dump/load cycles maintain consistency."""
with tempfile.TemporaryDirectory(ignore_cleanup_errors=on_windows) as temp_dir:
temp_path = Path(temp_dir)
Expand Down
14 changes: 12 additions & 2 deletions tests/embedder/test_fine_tuned.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from __future__ import annotations

from typing import TYPE_CHECKING, cast

import numpy as np
import pytest

Expand All @@ -6,8 +10,12 @@
from autointent.context.data_handler import DataHandler
from tests.conftest import tiny_sentence_transformer_config

if TYPE_CHECKING:
from autointent import Dataset
from autointent.custom_types import ListOfLabels


def test_model_updates_after_training(dataset):
def test_model_updates_after_training(dataset: Dataset) -> None:
"""Test that model weights actually change after training"""
pytest.importorskip("accelerate", reason="Accelerate library is required for this test")

Expand Down Expand Up @@ -35,9 +43,11 @@ def test_model_updates_after_training(dataset):
param.data.detach().cpu().numpy().copy() for param in backend._model.parameters() if param.requires_grad
]

# data_handler.train_labels returns ListOfGenericLabels (may contain None for OOS);
# the test dataset has no OOS, so cast to the strict ListOfLabels for the typed API.
backend.train(
utterances=data_handler.train_utterances(0)[:1000],
labels=data_handler.train_labels(0)[:1000],
labels=cast("ListOfLabels", data_handler.train_labels(0)[:1000]),
config=train_config,
)

Expand Down
Loading
Loading