Align KTO with DPO: Add TestKTOTrainerSlow with test_train_vlm_gemma_3n#6162
Align KTO with DPO: Add TestKTOTrainerSlow with test_train_vlm_gemma_3n#6162albertvillanova wants to merge 3 commits into
Conversation
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
There was a problem hiding this comment.
Cursor Bugbot has reviewed your changes using default effort and found 1 potential issue.
❌ Bugbot Autofix is OFF. To automatically fix reported issues with cloud agents, enable autofix in the Cursor dashboard.
Reviewed by Cursor Bugbot for commit 8bf9334. Configure here.
| model_init_kwargs={"dtype": "bfloat16"}, | ||
| report_to="none", | ||
| ) | ||
| trainer = KTOTrainer(model="google/gemma-3n-E2B-it", args=training_args, train_dataset=dataset) |
There was a problem hiding this comment.
Batch size one breaks KTOTrainer
Medium Severity
The slow Gemma VLM test sets per_device_train_batch_size=1 while using the default KTO loss, which keeps the KL term enabled. KTOTrainer raises during initialization when KL is on and the actual batch size is at most one, so the test fails as soon as the skip is removed—not during training.
Reviewed by Cursor Bugbot for commit 8bf9334. Configure here.


Align KTO with DPO: Add TestKTOTrainerSlow with test_train_vlm_gemma_3n.
Part of:
This PR adds a new slow-running test class to improve coverage for vision-language model (VLM) training, specifically targeting the
google/gemma-3n-E2B-itmodel. The new test is marked as slow and skipped by default due to external requirements, ensuring it doesn't affect standard test runs.Changes
Testing improvements:
TestKTOTrainerSlowtotests/experimental/test_kto_trainer.py, including a test for training the fullgoogle/gemma-3n-E2B-itmodel using theKTOTrainer. The test is marked with@pytest.mark.slowand skipped by default because the model is gated and requires a Hugging Face token.Note
Low Risk
Test-only change with no production code paths; the new test is skipped in default CI runs.
Overview
Adds
TestKTOTrainerSlowtotests/experimental/test_kto_trainer.py, mirroring the existing DPO slow-test pattern for Gemma 3n VLMs.The new
test_train_vlm_gemma_3nruns end-to-endKTOTrainertraining ongoogle/gemma-3n-E2B-itwith the conversational unpaired image preference dataset, checkstrain_lossis logged, and asserts trainable weights change exceptmodel.audio_tower/model.embed_audio(no audio in the dataset). It uses VLM-friendly settings (max_length=None, batch size 1, bfloat16) and is marked@pytest.mark.slowplus@pytest.mark.skipbecause the checkpoint is gated and needs an HF token.Reviewed by Cursor Bugbot for commit 8bf9334. Bugbot is set up for automated code reviews on this repo. Configure here.