diff --git a/tests/generation/intents/test_description_generation.py b/tests/generation/intents/test_description_generation.py index 0a79b1da9..d34d5fd70 100644 --- a/tests/generation/intents/test_description_generation.py +++ b/tests/generation/intents/test_description_generation.py @@ -1,4 +1,7 @@ +from __future__ import annotations + from collections import defaultdict +from typing import TYPE_CHECKING, cast from unittest.mock import AsyncMock, patch import pytest @@ -13,46 +16,49 @@ ) from autointent.schemas import Intent, Sample +if TYPE_CHECKING: + from autointent.generation import Generator + -def test_get_utterances_by_id_empty_input(): - utterances = [] +def test_get_utterances_by_id_empty_input() -> None: + utterances: list[Sample] = [] result = group_utterances_by_label(utterances) assert result == {} -def test_get_utterances_by_id_single_multiclass_utterance(): +def test_get_utterances_by_id_single_multiclass_utterance() -> None: samples = [Sample(utterance="Hello", label=1)] result = group_utterances_by_label(samples) assert result == {1: ["Hello"]} -def test_get_utterances_by_id_multiple_multiclass_same_label(): +def test_get_utterances_by_id_multiple_multiclass_same_label() -> None: samples = [Sample(utterance="Hello", label=1), Sample(utterance="Hi", label=1)] result = group_utterances_by_label(samples) assert result == {1: ["Hello", "Hi"]} -def test_get_utterances_by_id_single_multilabel_utterance(): +def test_get_utterances_by_id_single_multilabel_utterance() -> None: samples = [Sample(utterance="Good morning", label=[0, 1, 1, 0])] result = group_utterances_by_label(samples) expected_result = {1: ["Good morning"], 2: ["Good morning"]} assert result == expected_result -def test_get_utterances_by_id_multiple_multilabel_utterances(): +def test_get_utterances_by_id_multiple_multilabel_utterances() -> None: samples = [Sample(utterance="Good morning", label=[0, 1, 1, 0]), Sample(utterance="Good night", label=[0, 1, 0, 1])] result = group_utterances_by_label(samples) expected_result = {1: ["Good morning", "Good night"], 2: ["Good morning"], 3: ["Good night"]} assert result == expected_result -def test_get_utterances_by_id_oos_utterances(): +def test_get_utterances_by_id_oos_utterances() -> None: samples = [Sample(utterance="Unknown command", label=None), Sample(utterance="Hello", label=[0, 0, 1])] result = group_utterances_by_label(samples) assert result == {2: ["Hello"]} -def test_get_utterances_by_id_mixed_types(): +def test_get_utterances_by_id_mixed_types() -> None: samples = [ Sample(utterance="Hello", label=1), Sample(utterance="Good morning", label=[0, 1, 0, 1]), @@ -64,7 +70,7 @@ def test_get_utterances_by_id_mixed_types(): assert result == expected_result -def test_get_utterances_by_id_duplicate_texts_different_labels(): +def test_get_utterances_by_id_duplicate_texts_different_labels() -> None: samples = [Sample(utterance="Duplicate", label=1), Sample(utterance="Duplicate", label=2)] result = group_utterances_by_label(samples) expected_result = {1: ["Duplicate"], 2: ["Duplicate"]} @@ -72,7 +78,7 @@ def test_get_utterances_by_id_duplicate_texts_different_labels(): @pytest.mark.asyncio -async def test_create_intent_description_basic(): +async def test_create_intent_description_basic() -> None: client = AsyncMock() mock_create = client.get_chat_completion_async mock_create.return_value = "Generated description" @@ -83,7 +89,7 @@ async def test_create_intent_description_basic(): ) description = await create_intent_description( - client=client, + client=cast("Generator", client), intent_name="Greeting", utterances=utterances, prompt=prompt, @@ -94,7 +100,7 @@ async def test_create_intent_description_basic(): @pytest.mark.asyncio -async def test_create_intent_description_empty_intent_name(): +async def test_create_intent_description_empty_intent_name() -> None: client = AsyncMock() mock_create = client.get_chat_completion_async mock_create.return_value = "Generated description" @@ -105,7 +111,7 @@ async def test_create_intent_description_empty_intent_name(): ) description = await create_intent_description( - client=client, + client=cast("Generator", client), intent_name=None, utterances=utterances, prompt=prompt, @@ -116,18 +122,18 @@ async def test_create_intent_description_empty_intent_name(): @pytest.mark.asyncio -async def test_create_intent_description_empty_utterances_patterns(): +async def test_create_intent_description_empty_utterances_patterns() -> None: client = AsyncMock() mock_create = client.get_chat_completion_async mock_create.return_value = "Generated description" - utterances = [] + utterances: list[str] = [] prompt = PromptDescription( user_text="Describe intent {intent_name} with examples: {user_utterances}", ) description = await create_intent_description( - client=client, + client=cast("Generator", client), intent_name="Greeting", utterances=utterances, prompt=prompt, @@ -138,7 +144,7 @@ async def test_create_intent_description_empty_utterances_patterns(): @pytest.mark.asyncio -async def test_create_intent_description_large_utterances_patterns(): +async def test_create_intent_description_large_utterances_patterns() -> None: client = AsyncMock() mock_create = client.get_chat_completion_async mock_create.return_value = "Generated description" @@ -150,7 +156,7 @@ async def test_create_intent_description_large_utterances_patterns(): with patch("random.sample", side_effect=lambda x, k: x[:k]) as mock_sample: description = await create_intent_description( - client=client, + client=cast("Generator", client), intent_name="Greeting", utterances=utterances, prompt=prompt, @@ -162,7 +168,7 @@ async def test_create_intent_description_large_utterances_patterns(): @pytest.mark.asyncio -async def test_generate_intent_descriptions_basic(): +async def test_generate_intent_descriptions_basic() -> None: client = AsyncMock() mock_create = client.get_chat_completion_async mock_create.return_value = "Generated description" @@ -176,7 +182,7 @@ async def test_generate_intent_descriptions_basic(): user_text="Describe intent {intent_name} with examples: {user_utterances}", ) updated_intents = await generate_intent_descriptions( - client=client, + client=cast("Generator", client), intent_utterances=intent_utterances, intents=intents, prompt=prompt, @@ -187,7 +193,7 @@ async def test_generate_intent_descriptions_basic(): @pytest.mark.asyncio -async def test_generate_intent_descriptions_skip_existing_descriptions(): +async def test_generate_intent_descriptions_skip_existing_descriptions() -> None: client = AsyncMock() mock_create = client.get_chat_completion_async mock_create.return_value = "Generated description" @@ -207,7 +213,7 @@ async def test_generate_intent_descriptions_skip_existing_descriptions(): user_text="Describe intent {intent_name} with examples: {user_utterances}", ) updated_intents = await generate_intent_descriptions( - client=client, + client=cast("Generator", client), intent_utterances=intent_utterances, intents=intents, prompt=prompt, @@ -219,12 +225,12 @@ async def test_generate_intent_descriptions_skip_existing_descriptions(): @pytest.mark.asyncio -async def test_generate_intent_descriptions_empty_utterances_patterns(): +async def test_generate_intent_descriptions_empty_utterances_patterns() -> None: client = AsyncMock() mock_create = client.get_chat_completion_async mock_create.return_value = "Generated description" - intent_utterances = {} # No utterances for any intent + intent_utterances: dict[int, list[str]] = {} # No utterances for any intent intents = [ Intent(id=1, name="Greeting", description=None, regex_full_match=[], regex_partial_match=[]), ] @@ -232,7 +238,7 @@ async def test_generate_intent_descriptions_empty_utterances_patterns(): user_text="Describe intent {intent_name} with examples: {user_utterances}", ) updated_intents = await generate_intent_descriptions( - client=client, + client=cast("Generator", client), intent_utterances=intent_utterances, intents=intents, prompt=prompt, @@ -254,7 +260,7 @@ async def test_generate_intent_descriptions_empty_utterances_patterns(): ) -def test_enhance_dataset_with_descriptions_basic(): +def test_enhance_dataset_with_descriptions_basic() -> None: client = AsyncMock() with patch( "autointent.generation.intents._description_generation.generate_intent_descriptions", @@ -282,7 +288,7 @@ def test_enhance_dataset_with_descriptions_basic(): ) enhanced_dataset = generate_descriptions( dataset=dataset, - client=client, + client=cast("Generator", client), prompt=prompt, ) expected_intent_utterances = defaultdict(list, {0: ["Hello"], 1: ["Goodbye"]}) @@ -300,7 +306,7 @@ def test_enhance_dataset_with_descriptions_basic(): ) -def test_enhance_dataset_with_existing_descriptions(): +def test_enhance_dataset_with_existing_descriptions() -> None: client = AsyncMock() with patch( "autointent.generation.intents._description_generation.generate_intent_descriptions", @@ -328,7 +334,7 @@ def test_enhance_dataset_with_existing_descriptions(): ) enhanced_dataset = generate_descriptions( dataset=dataset, - client=client, + client=cast("Generator", client), prompt=prompt, ) expected_intent_utterances = defaultdict(list, {0: ["Hello"], 1: ["Goodbye"]}) diff --git a/tests/generation/structured_output/test_basics.py b/tests/generation/structured_output/test_basics.py index 07d86932e..e01fc6951 100644 --- a/tests/generation/structured_output/test_basics.py +++ b/tests/generation/structured_output/test_basics.py @@ -1,7 +1,9 @@ """Tests for structured output functionality.""" +from __future__ import annotations + import json -from typing import Literal +from typing import TYPE_CHECKING, Literal import httpx import pytest @@ -10,6 +12,9 @@ from autointent.generation import Generator from autointent.generation.chat_templates import Role +if TYPE_CHECKING: + from respx.router import MockRouter + class Person(BaseModel): reasoning: str = Field(description="Some preliminary reasoning to plan fields' values") @@ -57,22 +62,24 @@ def _chat_completion_response(content: str) -> httpx.Response: @pytest.fixture -def generator(respx_openai): +def generator(respx_openai: MockRouter) -> Generator: """Create a generator instance for testing.""" - return Generator(max_tokens=1000, use_cache=False) + # reason: Generator.__init__ types **generation_params as dict[str, Any] (src bug: + # should be Any). int kwargs are valid at runtime; flagged for Phase C. + return Generator(max_tokens=1000, use_cache=False) # type: ignore[arg-type] class TestStructuredOutput: """Test structured output functionality for different backends.""" - def test_basic_chat_completion(self, generator, respx_openai): + def test_basic_chat_completion(self, generator: Generator, respx_openai: MockRouter) -> None: respx_openai.post("/v1/chat/completions").mock(return_value=_chat_completion_response("hi! here's a joke")) response = generator.get_chat_completion(messages=[{"role": Role.USER, "content": "hi! tell me a joke"}]) assert isinstance(response, str) assert len(response) > 0 @pytest.mark.asyncio - async def test_async_chat_completion(self, generator, respx_openai): + async def test_async_chat_completion(self, generator: Generator, respx_openai: MockRouter) -> None: respx_openai.post("/v1/chat/completions").mock(return_value=_chat_completion_response("hi! here's a joke")) response = await generator.get_chat_completion_async( messages=[{"role": Role.USER, "content": "hi! tell me a joke"}] @@ -80,7 +87,7 @@ async def test_async_chat_completion(self, generator, respx_openai): assert isinstance(response, str) assert len(response) > 0 - def test_structured_output(self, generator, respx_openai): + def test_structured_output(self, generator: Generator, respx_openai: MockRouter) -> None: respx_openai.post("/v1/chat/completions").mock(return_value=_chat_completion_response(VALID_PERSON_JSON)) result = generator.get_structured_output_sync( messages=[{"role": Role.USER, "content": "Create a person"}], @@ -90,7 +97,7 @@ def test_structured_output(self, generator, respx_openai): assert isinstance(result, Person) @pytest.mark.asyncio - async def test_structured_output_async(self, generator, respx_openai): + async def test_structured_output_async(self, generator: Generator, respx_openai: MockRouter) -> None: respx_openai.post("/v1/chat/completions").mock(return_value=_chat_completion_response(VALID_PERSON_JSON)) result = await generator.get_structured_output_async( messages=[{"role": Role.USER, "content": "Create a person"}], diff --git a/tests/generation/structured_output/test_caching.py b/tests/generation/structured_output/test_caching.py index 5d93b87dc..80613727b 100644 --- a/tests/generation/structured_output/test_caching.py +++ b/tests/generation/structured_output/test_caching.py @@ -1,6 +1,9 @@ """Tests for Generator cache semantics.""" +from __future__ import annotations + import json +from typing import TYPE_CHECKING import httpx import pytest @@ -9,9 +12,16 @@ from autointent.generation import Generator from autointent.generation.chat_templates import Role +if TYPE_CHECKING: + from pathlib import Path + + from respx.router import MockRouter + + from autointent.generation.chat_templates import Message + @pytest.fixture(autouse=True) -def _isolated_cache(tmp_path, monkeypatch): +def _isolated_cache(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: """Redirect the structured-output disk cache to a fresh tmp dir each test.""" monkeypatch.setattr("autointent.generation._cache.user_cache_dir", lambda *_: str(tmp_path)) @@ -37,19 +47,26 @@ def _resp(name: str, value: int) -> httpx.Response: @pytest.fixture -def generator_with_cache(respx_openai): - return Generator(max_tokens=1000, use_cache=True, temperature=2) +def generator_with_cache(respx_openai: MockRouter) -> Generator: + # reason: Generator.__init__ types **generation_params as dict[str, Any] (src bug: + # should be Any). int kwargs are valid at runtime; flagged for Phase C. + return Generator(max_tokens=1000, use_cache=True, temperature=2) # type: ignore[arg-type] @pytest.fixture -def generator_without_cache(respx_openai): - return Generator(max_tokens=1000, use_cache=False, temperature=2) +def generator_without_cache(respx_openai: MockRouter) -> Generator: + # reason: same Generator src bug as above. + return Generator(max_tokens=1000, use_cache=False, temperature=2) # type: ignore[arg-type] @pytest.mark.asyncio -async def test_cache_hit(generator_with_cache, generator_without_cache, respx_openai): - messages = [{"role": Role.USER, "content": "Create a random simple model"}] - different_messages = [{"role": Role.USER, "content": "Create a person named John with value 333"}] +async def test_cache_hit( + generator_with_cache: Generator, + generator_without_cache: Generator, + respx_openai: MockRouter, +) -> None: + messages: list[Message] = [{"role": Role.USER, "content": "Create a random simple model"}] + different_messages: list[Message] = [{"role": Role.USER, "content": "Create a person named John with value 333"}] route = respx_openai.post("/v1/chat/completions").mock( side_effect=[ diff --git a/tests/generation/structured_output/test_retries.py b/tests/generation/structured_output/test_retries.py index 685bcb5a9..bff1dbe87 100644 --- a/tests/generation/structured_output/test_retries.py +++ b/tests/generation/structured_output/test_retries.py @@ -3,7 +3,7 @@ from __future__ import annotations import json -from typing import Literal +from typing import TYPE_CHECKING, Literal import httpx import pytest @@ -12,6 +12,9 @@ from autointent.generation import Generator, RetriesExceededError from autointent.generation.chat_templates import Role +if TYPE_CHECKING: + from respx.router import MockRouter + class Person(BaseModel): reasoning: str = Field(description="Some preliminary reasoning to plan fields' values") @@ -75,12 +78,16 @@ def _resp(content: str) -> httpx.Response: @pytest.fixture -def generator(respx_openai): - return Generator(max_tokens=1000, use_cache=False) +def generator(respx_openai: MockRouter) -> Generator: + # reason: Generator.__init__ types **generation_params as dict[str, Any] (src bug: + # should be Any). int kwargs are valid at runtime; flagged for Phase C. + return Generator(max_tokens=1000, use_cache=False) # type: ignore[arg-type] class TestStructuredOutput: - def test_structured_output_sync_success_with_enough_retries(self, generator, respx_openai): + def test_structured_output_sync_success_with_enough_retries( + self, generator: Generator, respx_openai: MockRouter + ) -> None: respx_openai.post("/v1/chat/completions").mock( side_effect=[_resp(INVALID_PERSON_JSON), _resp(INVALID_PERSON_JSON), _resp(VALID_PERSON_JSON)] ) @@ -92,7 +99,9 @@ def test_structured_output_sync_success_with_enough_retries(self, generator, res assert isinstance(result, Person) @pytest.mark.asyncio - async def test_structured_output_async_success_with_enough_retries(self, generator, respx_openai): + async def test_structured_output_async_success_with_enough_retries( + self, generator: Generator, respx_openai: MockRouter + ) -> None: respx_openai.post("/v1/chat/completions").mock( side_effect=[_resp(INVALID_PERSON_JSON), _resp(INVALID_PERSON_JSON), _resp(VALID_PERSON_JSON)] ) @@ -103,7 +112,9 @@ async def test_structured_output_async_success_with_enough_retries(self, generat ) assert isinstance(result, Person) - def test_structured_output_sync_failure_with_insufficient_retries(self, generator, respx_openai): + def test_structured_output_sync_failure_with_insufficient_retries( + self, generator: Generator, respx_openai: MockRouter + ) -> None: respx_openai.post("/v1/chat/completions").mock(return_value=_resp(INVALID_PERSON_JSON)) with pytest.raises(RetriesExceededError): generator.get_structured_output_sync( @@ -113,7 +124,9 @@ def test_structured_output_sync_failure_with_insufficient_retries(self, generato ) @pytest.mark.asyncio - async def test_structured_output_async_failure_with_insufficient_retries(self, generator, respx_openai): + async def test_structured_output_async_failure_with_insufficient_retries( + self, generator: Generator, respx_openai: MockRouter + ) -> None: respx_openai.post("/v1/chat/completions").mock(return_value=_resp(INVALID_PERSON_JSON)) with pytest.raises(RetriesExceededError): await generator.get_structured_output_async( diff --git a/tests/generation/utterances/test_adversarial.py b/tests/generation/utterances/test_adversarial.py index 8fc33f33e..3e2ef9899 100644 --- a/tests/generation/utterances/test_adversarial.py +++ b/tests/generation/utterances/test_adversarial.py @@ -8,7 +8,7 @@ @pytest.fixture -def dataset(): +def dataset() -> Dataset: return Dataset.from_dict( { "intents": [ @@ -24,7 +24,7 @@ def dataset(): ) -def test_human_utterance_generator_sync(dataset): +def test_human_utterance_generator_sync(dataset: Dataset) -> None: mock_llm = Mock() mock_llm.get_chat_completion.return_value = "Human-like utterance" @@ -44,7 +44,7 @@ def test_human_utterance_generator_sync(dataset): assert all("label" in sample.dict() for sample in new_samples) -def test_human_utterance_generator_async(dataset): +def test_human_utterance_generator_async(dataset: Dataset) -> None: mock_llm = AsyncMock() mock_llm.get_chat_completion_async.return_value = "Human-like utterance" @@ -63,7 +63,7 @@ def test_human_utterance_generator_async(dataset): assert all("label" in sample.dict() for sample in new_samples) -def test_human_utterance_generator_respects_critic(dataset): +def test_human_utterance_generator_respects_critic(dataset: Dataset) -> None: mock_llm = Mock() mock_llm.get_chat_completion.return_value = "Generated utterance" diff --git a/tests/generation/utterances/test_balancer.py b/tests/generation/utterances/test_balancer.py index f25792496..1fcda33e9 100644 --- a/tests/generation/utterances/test_balancer.py +++ b/tests/generation/utterances/test_balancer.py @@ -12,7 +12,7 @@ @pytest.fixture -def mock_generator(): +def mock_generator() -> Mock: generator = Mock(spec=Generator) generator.get_chat_completion.return_value = "test_utterance" generator.get_chat_completion_async = AsyncMock(return_value="test_utterance") @@ -20,12 +20,12 @@ def mock_generator(): @pytest.fixture -def mock_prompt_maker(): +def mock_prompt_maker() -> Mock: return Mock(return_value=[Mock()]) @pytest.fixture -def unbalanced_dataset(): +def unbalanced_dataset() -> Dataset: return Dataset.from_dict( { "intents": [{"id": 0, "name": "A"}, {"id": 1, "name": "B"}], @@ -38,7 +38,7 @@ def unbalanced_dataset(): ) -def test_balancer(unbalanced_dataset, mock_generator, mock_prompt_maker): +def test_balancer(unbalanced_dataset: Dataset, mock_generator: Mock, mock_prompt_maker: Mock) -> None: balancer = DatasetBalancer(generator=mock_generator, prompt_maker=mock_prompt_maker) logger.info("Before balancing:") for sample in unbalanced_dataset[Split.TRAIN]: diff --git a/tests/generation/utterances/test_basic_synthesizer.py b/tests/generation/utterances/test_basic_synthesizer.py index f89ebac82..7b96db059 100644 --- a/tests/generation/utterances/test_basic_synthesizer.py +++ b/tests/generation/utterances/test_basic_synthesizer.py @@ -1,10 +1,16 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING from unittest.mock import AsyncMock, Mock from autointent.generation.chat_templates import EnglishSynthesizerTemplate from autointent.generation.utterances import UtteranceGenerator +if TYPE_CHECKING: + from autointent import Dataset + -def has_unfilled_fields(template): +def has_unfilled_fields(template: str) -> bool: try: # Attempt to format the string with empty values template.format(**{}) # noqa: PIE804 @@ -13,22 +19,24 @@ def has_unfilled_fields(template): return True # Unfilled fields detected -def test_default_chat_template(dataset): +def test_default_chat_template(dataset: Dataset) -> None: template = EnglishSynthesizerTemplate(dataset, split="train_0") prompt = template(dataset.intents[0], n_examples=1) for msg in prompt: assert not has_unfilled_fields(msg["content"]) - assert "extra_instructions" not in prompt + # reason: legacy sentinel check; prompt is list[Message] (TypedDict), so a str is + # never in it. Preserving original test behavior per Phase B no-semantic-change rule. + assert "extra_instructions" not in prompt # type: ignore[comparison-overlap] -def test_extra_instructions(dataset): +def test_extra_instructions(dataset: Dataset) -> None: template = EnglishSynthesizerTemplate(dataset, split="train_0", extra_instructions="football") prompt = template(dataset.intents[0], n_examples=1)[0]["content"] assert "extra_instructions" not in prompt assert "football" in prompt -def test_on_dataset(dataset): +def test_on_dataset(dataset: Dataset) -> None: mock_llm = Mock() mock_llm.get_chat_completion.return_value = "1. LLM answer" @@ -53,7 +61,7 @@ def test_on_dataset(dataset): assert len(new_samples) == len(dataset.intents) -def test_on_dataset_async(dataset): +def test_on_dataset_async(dataset: Dataset) -> None: mock_llm = AsyncMock() mock_llm.get_chat_completion_async.return_value = "1. LLM answer" @@ -78,7 +86,7 @@ def test_on_dataset_async(dataset): assert len(new_samples) == len(dataset.intents) -def test_on_dataset_async_with_batch_size(dataset): +def test_on_dataset_async_with_batch_size(dataset: Dataset) -> None: mock_llm = AsyncMock() mock_llm.get_chat_completion_async.return_value = "1. LLM answer" diff --git a/tests/generation/utterances/test_evolver.py b/tests/generation/utterances/test_evolver.py index b23dc1798..57575ad4e 100644 --- a/tests/generation/utterances/test_evolver.py +++ b/tests/generation/utterances/test_evolver.py @@ -1,3 +1,6 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING from unittest.mock import AsyncMock, Mock import pytest @@ -5,14 +8,17 @@ from autointent.generation.chat_templates import AbstractEvolution from autointent.generation.utterances import UtteranceEvolver +if TYPE_CHECKING: + from autointent import Dataset + -def test_default_chat_template(dataset): +def test_default_chat_template(dataset: Dataset) -> None: template = AbstractEvolution() prompt = template("some utterance", dataset.intents[0]) assert "some utterance" in prompt[-1]["content"] -def test_on_dataset(dataset): +def test_on_dataset(dataset: Dataset) -> None: mock_llm = Mock() mock_llm.get_chat_completion.return_value = "LLM answer" @@ -46,7 +52,7 @@ def test_on_dataset(dataset): assert set(new_samples.column_names) == set(dataset[split_name].column_names) -def test_on_dataset_evolver_async(dataset): +def test_on_dataset_evolver_async(dataset: Dataset) -> None: mock_llm = AsyncMock() mock_llm.get_chat_completion_async.return_value = "LLM answer" @@ -77,7 +83,7 @@ def test_on_dataset_evolver_async(dataset): ) -def test_on_dataset_evolver_async_with_batch_size(dataset): +def test_on_dataset_evolver_async_with_batch_size(dataset: Dataset) -> None: mock_llm = AsyncMock() mock_llm.get_chat_completion_async.return_value = "LLM answer" diff --git a/tests/generation/utterances/test_generator.py b/tests/generation/utterances/test_generator.py index cc6c54011..12e5bccd7 100644 --- a/tests/generation/utterances/test_generator.py +++ b/tests/generation/utterances/test_generator.py @@ -1,3 +1,6 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -5,29 +8,32 @@ from autointent.generation import Generator from autointent.generation.chat_templates import Message +if TYPE_CHECKING: + from collections.abc import Iterator + pytest.importorskip("openai", reason="OpenAI library is required") @pytest.fixture(autouse=True) -def set_env_vars(monkeypatch): +def set_env_vars(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("OPENAI_BASE_URL", "https://api.openai.com/v1") monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key") monkeypatch.setenv("OPENAI_MODEL_NAME", "gpt-3.5-turbo") @pytest.fixture -def mock_openai_client(): +def mock_openai_client() -> Iterator[MagicMock]: with patch("openai.OpenAI") as mock_client: yield mock_client -def test_generator_initialization(mock_openai_client): +def test_generator_initialization(mock_openai_client: MagicMock) -> None: generator = Generator() assert generator.client == mock_openai_client.return_value assert generator.model_name == "gpt-3.5-turbo" -def test_get_chat_completion(mock_openai_client): +def test_get_chat_completion(mock_openai_client: MagicMock) -> None: mock_response = MagicMock() mock_response.choices = [MagicMock(message=MagicMock(content="Test response"))] mock_openai_client.return_value.chat.completions.create.return_value = mock_response @@ -41,7 +47,7 @@ def test_get_chat_completion(mock_openai_client): @pytest.mark.asyncio -async def test_get_chat_completion_async(): +async def test_get_chat_completion_async() -> None: test_messages = [Message(role="user", content="Hello, how are you?")] mock_response = MagicMock()