Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 35 additions & 29 deletions tests/generation/intents/test_description_generation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from __future__ import annotations

from collections import defaultdict
from typing import TYPE_CHECKING, cast
from unittest.mock import AsyncMock, patch

import pytest
Expand All @@ -13,46 +16,49 @@
)
from autointent.schemas import Intent, Sample

if TYPE_CHECKING:
from autointent.generation import Generator


def test_get_utterances_by_id_empty_input():
utterances = []
def test_get_utterances_by_id_empty_input() -> None:
utterances: list[Sample] = []
result = group_utterances_by_label(utterances)
assert result == {}


def test_get_utterances_by_id_single_multiclass_utterance():
def test_get_utterances_by_id_single_multiclass_utterance() -> None:
samples = [Sample(utterance="Hello", label=1)]
result = group_utterances_by_label(samples)
assert result == {1: ["Hello"]}


def test_get_utterances_by_id_multiple_multiclass_same_label():
def test_get_utterances_by_id_multiple_multiclass_same_label() -> None:
samples = [Sample(utterance="Hello", label=1), Sample(utterance="Hi", label=1)]
result = group_utterances_by_label(samples)
assert result == {1: ["Hello", "Hi"]}


def test_get_utterances_by_id_single_multilabel_utterance():
def test_get_utterances_by_id_single_multilabel_utterance() -> None:
samples = [Sample(utterance="Good morning", label=[0, 1, 1, 0])]
result = group_utterances_by_label(samples)
expected_result = {1: ["Good morning"], 2: ["Good morning"]}
assert result == expected_result


def test_get_utterances_by_id_multiple_multilabel_utterances():
def test_get_utterances_by_id_multiple_multilabel_utterances() -> None:
samples = [Sample(utterance="Good morning", label=[0, 1, 1, 0]), Sample(utterance="Good night", label=[0, 1, 0, 1])]
result = group_utterances_by_label(samples)
expected_result = {1: ["Good morning", "Good night"], 2: ["Good morning"], 3: ["Good night"]}
assert result == expected_result


def test_get_utterances_by_id_oos_utterances():
def test_get_utterances_by_id_oos_utterances() -> None:
samples = [Sample(utterance="Unknown command", label=None), Sample(utterance="Hello", label=[0, 0, 1])]
result = group_utterances_by_label(samples)
assert result == {2: ["Hello"]}


def test_get_utterances_by_id_mixed_types():
def test_get_utterances_by_id_mixed_types() -> None:
samples = [
Sample(utterance="Hello", label=1),
Sample(utterance="Good morning", label=[0, 1, 0, 1]),
Expand All @@ -64,15 +70,15 @@ def test_get_utterances_by_id_mixed_types():
assert result == expected_result


def test_get_utterances_by_id_duplicate_texts_different_labels():
def test_get_utterances_by_id_duplicate_texts_different_labels() -> None:
samples = [Sample(utterance="Duplicate", label=1), Sample(utterance="Duplicate", label=2)]
result = group_utterances_by_label(samples)
expected_result = {1: ["Duplicate"], 2: ["Duplicate"]}
assert result == expected_result


@pytest.mark.asyncio
async def test_create_intent_description_basic():
async def test_create_intent_description_basic() -> None:
client = AsyncMock()
mock_create = client.get_chat_completion_async
mock_create.return_value = "Generated description"
Expand All @@ -83,7 +89,7 @@ async def test_create_intent_description_basic():
)

description = await create_intent_description(
client=client,
client=cast("Generator", client),
intent_name="Greeting",
utterances=utterances,
prompt=prompt,
Expand All @@ -94,7 +100,7 @@ async def test_create_intent_description_basic():


@pytest.mark.asyncio
async def test_create_intent_description_empty_intent_name():
async def test_create_intent_description_empty_intent_name() -> None:
client = AsyncMock()
mock_create = client.get_chat_completion_async
mock_create.return_value = "Generated description"
Expand All @@ -105,7 +111,7 @@ async def test_create_intent_description_empty_intent_name():
)

description = await create_intent_description(
client=client,
client=cast("Generator", client),
intent_name=None,
utterances=utterances,
prompt=prompt,
Expand All @@ -116,18 +122,18 @@ async def test_create_intent_description_empty_intent_name():


@pytest.mark.asyncio
async def test_create_intent_description_empty_utterances_patterns():
async def test_create_intent_description_empty_utterances_patterns() -> None:
client = AsyncMock()
mock_create = client.get_chat_completion_async
mock_create.return_value = "Generated description"

utterances = []
utterances: list[str] = []
prompt = PromptDescription(
user_text="Describe intent {intent_name} with examples: {user_utterances}",
)

description = await create_intent_description(
client=client,
client=cast("Generator", client),
intent_name="Greeting",
utterances=utterances,
prompt=prompt,
Expand All @@ -138,7 +144,7 @@ async def test_create_intent_description_empty_utterances_patterns():


@pytest.mark.asyncio
async def test_create_intent_description_large_utterances_patterns():
async def test_create_intent_description_large_utterances_patterns() -> None:
client = AsyncMock()
mock_create = client.get_chat_completion_async
mock_create.return_value = "Generated description"
Expand All @@ -150,7 +156,7 @@ async def test_create_intent_description_large_utterances_patterns():

with patch("random.sample", side_effect=lambda x, k: x[:k]) as mock_sample:
description = await create_intent_description(
client=client,
client=cast("Generator", client),
intent_name="Greeting",
utterances=utterances,
prompt=prompt,
Expand All @@ -162,7 +168,7 @@ async def test_create_intent_description_large_utterances_patterns():


@pytest.mark.asyncio
async def test_generate_intent_descriptions_basic():
async def test_generate_intent_descriptions_basic() -> None:
client = AsyncMock()
mock_create = client.get_chat_completion_async
mock_create.return_value = "Generated description"
Expand All @@ -176,7 +182,7 @@ async def test_generate_intent_descriptions_basic():
user_text="Describe intent {intent_name} with examples: {user_utterances}",
)
updated_intents = await generate_intent_descriptions(
client=client,
client=cast("Generator", client),
intent_utterances=intent_utterances,
intents=intents,
prompt=prompt,
Expand All @@ -187,7 +193,7 @@ async def test_generate_intent_descriptions_basic():


@pytest.mark.asyncio
async def test_generate_intent_descriptions_skip_existing_descriptions():
async def test_generate_intent_descriptions_skip_existing_descriptions() -> None:
client = AsyncMock()
mock_create = client.get_chat_completion_async
mock_create.return_value = "Generated description"
Expand All @@ -207,7 +213,7 @@ async def test_generate_intent_descriptions_skip_existing_descriptions():
user_text="Describe intent {intent_name} with examples: {user_utterances}",
)
updated_intents = await generate_intent_descriptions(
client=client,
client=cast("Generator", client),
intent_utterances=intent_utterances,
intents=intents,
prompt=prompt,
Expand All @@ -219,20 +225,20 @@ async def test_generate_intent_descriptions_skip_existing_descriptions():


@pytest.mark.asyncio
async def test_generate_intent_descriptions_empty_utterances_patterns():
async def test_generate_intent_descriptions_empty_utterances_patterns() -> None:
client = AsyncMock()
mock_create = client.get_chat_completion_async
mock_create.return_value = "Generated description"

intent_utterances = {} # No utterances for any intent
intent_utterances: dict[int, list[str]] = {} # No utterances for any intent
intents = [
Intent(id=1, name="Greeting", description=None, regex_full_match=[], regex_partial_match=[]),
]
prompt = PromptDescription(
user_text="Describe intent {intent_name} with examples: {user_utterances}",
)
updated_intents = await generate_intent_descriptions(
client=client,
client=cast("Generator", client),
intent_utterances=intent_utterances,
intents=intents,
prompt=prompt,
Expand All @@ -254,7 +260,7 @@ async def test_generate_intent_descriptions_empty_utterances_patterns():
)


def test_enhance_dataset_with_descriptions_basic():
def test_enhance_dataset_with_descriptions_basic() -> None:
client = AsyncMock()
with patch(
"autointent.generation.intents._description_generation.generate_intent_descriptions",
Expand Down Expand Up @@ -282,7 +288,7 @@ def test_enhance_dataset_with_descriptions_basic():
)
enhanced_dataset = generate_descriptions(
dataset=dataset,
client=client,
client=cast("Generator", client),
prompt=prompt,
)
expected_intent_utterances = defaultdict(list, {0: ["Hello"], 1: ["Goodbye"]})
Expand All @@ -300,7 +306,7 @@ def test_enhance_dataset_with_descriptions_basic():
)


def test_enhance_dataset_with_existing_descriptions():
def test_enhance_dataset_with_existing_descriptions() -> None:
client = AsyncMock()
with patch(
"autointent.generation.intents._description_generation.generate_intent_descriptions",
Expand Down Expand Up @@ -328,7 +334,7 @@ def test_enhance_dataset_with_existing_descriptions():
)
enhanced_dataset = generate_descriptions(
dataset=dataset,
client=client,
client=cast("Generator", client),
prompt=prompt,
)
expected_intent_utterances = defaultdict(list, {0: ["Hello"], 1: ["Goodbye"]})
Expand Down
21 changes: 14 additions & 7 deletions tests/generation/structured_output/test_basics.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Tests for structured output functionality."""

from __future__ import annotations

import json
from typing import Literal
from typing import TYPE_CHECKING, Literal

import httpx
import pytest
Expand All @@ -10,6 +12,9 @@
from autointent.generation import Generator
from autointent.generation.chat_templates import Role

if TYPE_CHECKING:
from respx.router import MockRouter


class Person(BaseModel):
reasoning: str = Field(description="Some preliminary reasoning to plan fields' values")
Expand Down Expand Up @@ -57,30 +62,32 @@ def _chat_completion_response(content: str) -> httpx.Response:


@pytest.fixture
def generator(respx_openai):
def generator(respx_openai: MockRouter) -> Generator:
"""Create a generator instance for testing."""
return Generator(max_tokens=1000, use_cache=False)
# reason: Generator.__init__ types **generation_params as dict[str, Any] (src bug:
# should be Any). int kwargs are valid at runtime; flagged for Phase C.
return Generator(max_tokens=1000, use_cache=False) # type: ignore[arg-type]


class TestStructuredOutput:
"""Test structured output functionality for different backends."""

def test_basic_chat_completion(self, generator, respx_openai):
def test_basic_chat_completion(self, generator: Generator, respx_openai: MockRouter) -> None:
respx_openai.post("/v1/chat/completions").mock(return_value=_chat_completion_response("hi! here's a joke"))
response = generator.get_chat_completion(messages=[{"role": Role.USER, "content": "hi! tell me a joke"}])
assert isinstance(response, str)
assert len(response) > 0

@pytest.mark.asyncio
async def test_async_chat_completion(self, generator, respx_openai):
async def test_async_chat_completion(self, generator: Generator, respx_openai: MockRouter) -> None:
respx_openai.post("/v1/chat/completions").mock(return_value=_chat_completion_response("hi! here's a joke"))
response = await generator.get_chat_completion_async(
messages=[{"role": Role.USER, "content": "hi! tell me a joke"}]
)
assert isinstance(response, str)
assert len(response) > 0

def test_structured_output(self, generator, respx_openai):
def test_structured_output(self, generator: Generator, respx_openai: MockRouter) -> None:
respx_openai.post("/v1/chat/completions").mock(return_value=_chat_completion_response(VALID_PERSON_JSON))
result = generator.get_structured_output_sync(
messages=[{"role": Role.USER, "content": "Create a person"}],
Expand All @@ -90,7 +97,7 @@ def test_structured_output(self, generator, respx_openai):
assert isinstance(result, Person)

@pytest.mark.asyncio
async def test_structured_output_async(self, generator, respx_openai):
async def test_structured_output_async(self, generator: Generator, respx_openai: MockRouter) -> None:
respx_openai.post("/v1/chat/completions").mock(return_value=_chat_completion_response(VALID_PERSON_JSON))
result = await generator.get_structured_output_async(
messages=[{"role": Role.USER, "content": "Create a person"}],
Expand Down
33 changes: 25 additions & 8 deletions tests/generation/structured_output/test_caching.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""Tests for Generator cache semantics."""

from __future__ import annotations

import json
from typing import TYPE_CHECKING

import httpx
import pytest
Expand All @@ -9,9 +12,16 @@
from autointent.generation import Generator
from autointent.generation.chat_templates import Role

if TYPE_CHECKING:
from pathlib import Path

from respx.router import MockRouter

from autointent.generation.chat_templates import Message


@pytest.fixture(autouse=True)
def _isolated_cache(tmp_path, monkeypatch):
def _isolated_cache(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
"""Redirect the structured-output disk cache to a fresh tmp dir each test."""
monkeypatch.setattr("autointent.generation._cache.user_cache_dir", lambda *_: str(tmp_path))

Expand All @@ -37,19 +47,26 @@ def _resp(name: str, value: int) -> httpx.Response:


@pytest.fixture
def generator_with_cache(respx_openai):
return Generator(max_tokens=1000, use_cache=True, temperature=2)
def generator_with_cache(respx_openai: MockRouter) -> Generator:
# reason: Generator.__init__ types **generation_params as dict[str, Any] (src bug:
# should be Any). int kwargs are valid at runtime; flagged for Phase C.
return Generator(max_tokens=1000, use_cache=True, temperature=2) # type: ignore[arg-type]


@pytest.fixture
def generator_without_cache(respx_openai):
return Generator(max_tokens=1000, use_cache=False, temperature=2)
def generator_without_cache(respx_openai: MockRouter) -> Generator:
# reason: same Generator src bug as above.
return Generator(max_tokens=1000, use_cache=False, temperature=2) # type: ignore[arg-type]


@pytest.mark.asyncio
async def test_cache_hit(generator_with_cache, generator_without_cache, respx_openai):
messages = [{"role": Role.USER, "content": "Create a random simple model"}]
different_messages = [{"role": Role.USER, "content": "Create a person named John with value 333"}]
async def test_cache_hit(
generator_with_cache: Generator,
generator_without_cache: Generator,
respx_openai: MockRouter,
) -> None:
messages: list[Message] = [{"role": Role.USER, "content": "Create a random simple model"}]
different_messages: list[Message] = [{"role": Role.USER, "content": "Create a person named John with value 333"}]

route = respx_openai.post("/v1/chat/completions").mock(
side_effect=[
Expand Down
Loading
Loading