@pytest.mark.slow
class TestKTOTrainerSlow(TrlTestCase):
    # Gemma 3n relies on a timm encoder, so building a tiny stand-in checkpoint for tests is impractical.
    # The full model is exercised instead, and the whole class is marked slow to keep it out of default runs.
    @pytest.mark.skip(reason="Model google/gemma-3n-E2B-it is gated and requires HF token")
    @require_vision
    def test_train_vlm_gemma_3n(self):
        # Train KTO on the full Gemma 3n checkpoint, then verify that a train loss was logged and
        # that every non-audio parameter differs from its pre-training value.
        train_data = load_dataset(
            "trl-internal-testing/zen-image",
            "conversational_unpaired_preference",
            split="train",
        )

        config = KTOConfig(
            output_dir=self.tmp_dir,
            # Gradients are tiny here; the default learning rate can leave weights unchanged.
            learning_rate=0.1,
            # Truncation may strip image tokens from VLM inputs and trigger errors, so disable it.
            max_length=None,
            # VLM training is memory hungry; a batch size of 1 avoids OOM.
            per_device_train_batch_size=1,
            model_init_kwargs={"dtype": "bfloat16"},
            report_to="none",
        )
        trainer = KTOTrainer(model="google/gemma-3n-E2B-it", args=config, train_dataset=train_data)

        # Snapshot every named parameter before training so updates can be detected afterwards.
        params_before = {}
        for name, tensor in trainer.model.named_parameters():
            params_before[name] = tensor.clone()

        trainer.train()

        final_log = trainer.state.log_history[-1]
        assert final_log["train_loss"] is not None

        # The dataset carries no audio, so audio tower/embedding weights are expected to stay untouched.
        audio_markers = ("model.audio_tower", "model.embed_audio")

        # Every remaining parameter must differ from its snapshot.
        for n, param in params_before.items():
            new_param = trainer.model.get_parameter(n)
            if any(marker in n for marker in audio_markers):
                continue
            assert not torch.equal(param, new_param), f"Param {n} is not updated"