# Patch summary: pass the `presence_penalty` and `frequency_penalty` model
# settings as keyword arguments in the call shown in `_completions_create`,
# store the keyword arguments received by `MockMistralAI.chat_completions_create`
# in a new `last_kwargs` field, and add a test asserting both penalties arrive.
# NOTE(review): line breaks and indentation were restored from a
# whitespace-collapsed copy of this patch -- confirm the context-line
# indentation against the target files before applying.
diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py
index 6bc04d9664..cc1cdceed2 100644
--- a/pydantic_ai_slim/pydantic_ai/models/mistral.py
+++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py
@@ -264,6 +264,8 @@ async def _completions_create(
                 top_p=model_settings.get('top_p', 1),
                 timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
                 random_seed=model_settings.get('seed', UNSET),
+                presence_penalty=model_settings.get('presence_penalty'),
+                frequency_penalty=model_settings.get('frequency_penalty'),
                 stop=model_settings.get('stop_sequences', None),
                 http_headers={'User-Agent': get_user_agent()},
             )
diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py
index 80f3f0eedf..4f18ddcbb3 100644
--- a/tests/models/test_mistral.py
+++ b/tests/models/test_mistral.py
@@ -91,6 +91,7 @@ class MockMistralAI:
     completions: MockChatCompletion | Sequence[MockChatCompletion] | None = None
     stream: Sequence[MockCompletionEvent] | Sequence[Sequence[MockCompletionEvent]] | None = None
     index: int = 0
+    last_kwargs: dict[str, Any] | None = None
 
     @cached_property
     def sdk_configuration(self) -> MockSdkConfiguration:
@@ -120,6 +121,7 @@ def create_stream_mock(
     async def chat_completions_create(  # pragma: lax no cover
         self, *_args: Any, stream: bool = False, **_kwargs: Any
     ) -> MistralChatCompletionResponse | MockAsyncStream[MockCompletionEvent]:
+        self.last_kwargs = _kwargs
         if stream or self.stream:
             assert self.stream is not None, 'you can only use `stream=True` if `stream` is provided'
             if isinstance(self.stream[0], list):
@@ -281,6 +283,22 @@ async def test_multiple_completions(allow_model_requests: None):
     )
 
 
+async def test_non_streaming_forwards_penalties(allow_model_requests: None):
+    completion = completion_message(MistralAssistantMessage(content='hello'))
+    mock_client = MockMistralAI.create_mock(completion)
+    model = MistralModel('mistral-large-latest', provider=MistralProvider(mistral_client=mock_client))
+    agent = Agent(model=model)
+
+    await agent.run(
+        'hello',
+        model_settings={'presence_penalty': 0.5, 'frequency_penalty': 0.25},
+    )
+
+    assert mock_client.last_kwargs is not None
+    assert mock_client.last_kwargs['presence_penalty'] == 0.5
+    assert mock_client.last_kwargs['frequency_penalty'] == 0.25
+
+
 async def test_three_completions(allow_model_requests: None):
     completions = [
         completion_message(