Align KTO with DPO: Add TestKTOTrainerSlow with test_train_vlm_gemma_3n #6162

Open
albertvillanova wants to merge 3 commits into main from align-kto-dpo-tests-vlm-slow

Conversation

@albertvillanova albertvillanova commented Jun 24, 2026

Member

Align KTO with DPO: Add TestKTOTrainerSlow with test_train_vlm_gemma_3n.

Part of:

This PR adds a new slow-running test class to improve coverage for vision-language model (VLM) training, specifically targeting the google/gemma-3n-E2B-it model. The new test is marked as slow and skipped by default due to external requirements, ensuring it doesn't affect standard test runs.

Changes

Testing improvements:

  • Added a new slow test class TestKTOTrainerSlow to tests/experimental/test_kto_trainer.py, including a test for training the full google/gemma-3n-E2B-it model using the KTOTrainer. The test is marked with @pytest.mark.slow and skipped by default because the model is gated and requires a Hugging Face token.
  • The new test ensures that all model parameters (except those related to audio, which are not used by the dataset) are updated during training, increasing test coverage for VLMs with timm encoders.
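The parameter-update check described above can be sketched roughly as follows. This is a minimal, framework-free illustration, not the code from the PR: the helper name `find_unchanged_params`, the toy parameter names, and the use of plain lists of floats in place of tensors are all assumptions made for the example. Only the two skipped name fragments, `model.audio_tower` and `model.embed_audio`, come from the PR description.

```python
import math

# Name fragments for parameters the dataset never exercises (taken from the PR text).
AUDIO_FRAGMENTS = ("model.audio_tower", "model.embed_audio")


def find_unchanged_params(before, after, skip_fragments=()):
    """Return the names of parameters whose values are identical before and after.

    `before` and `after` map parameter names to flat lists of floats; these
    stand in for the tensors a real test would snapshot (hypothetical layout).
    """
    unchanged = []
    for name, old_values in before.items():
        # Skip parameters that are expected to stay frozen in practice.
        if any(fragment in name for fragment in skip_fragments):
            continue
        new_values = after[name]
        if all(
            math.isclose(old, new, rel_tol=1e-12, abs_tol=1e-12)
            for old, new in zip(old_values, new_values)
        ):
            unchanged.append(name)
    return unchanged


# Toy snapshot taken before "training" (values are made up for illustration).
before = {
    "model.language_model.w": [0.10, 0.20],
    "model.vision_tower.w": [0.30],
    "model.audio_tower.w": [0.50],
}
# Toy snapshot after "training": the audio weights received no gradient.
after = {
    "model.language_model.w": [0.11, 0.19],
    "model.vision_tower.w": [0.31],
    "model.audio_tower.w": [0.50],
}

# With the audio fragments skipped, every remaining parameter has changed.
assert find_unchanged_params(before, after, AUDIO_FRAGMENTS) == []
```

Without the skip list, the same call would flag the audio parameter as not updated, which is why the test excludes those names rather than asserting on every parameter.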

Note

Low Risk
Test-only change with no production code paths; the new test is skipped in default CI runs.

Overview
Adds TestKTOTrainerSlow to tests/experimental/test_kto_trainer.py, mirroring the existing DPO slow-test pattern for Gemma 3n VLMs.

The new test_train_vlm_gemma_3n runs end-to-end KTOTrainer training on google/gemma-3n-E2B-it with the conversational unpaired image preference dataset, checks train_loss is logged, and asserts trainable weights change except model.audio_tower / model.embed_audio (no audio in the dataset). It uses VLM-friendly settings (max_length=None, batch size 1, bfloat16) and is marked @pytest.mark.slow plus @pytest.mark.skip because the checkpoint is gated and needs an HF token.

Reviewed by Cursor Bugbot for commit 8bf9334. Bugbot is set up for automated code reviews on this repo. Configure here.

@bot-ci-comment


The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@cursor cursor Bot left a comment


Cursor Bugbot has reviewed your changes using default effort and found 1 potential issue.


model_init_kwargs={"dtype": "bfloat16"},
report_to="none",
)
trainer = KTOTrainer(model="google/gemma-3n-E2B-it", args=training_args, train_dataset=dataset)


Batch size one breaks KTOTrainer

Medium Severity

The slow Gemma VLM test sets per_device_train_batch_size=1 while using the default KTO loss, which keeps the KL term enabled. KTOTrainer raises during initialization when KL is on and the actual batch size is at most one, so once the skip is removed the test fails at trainer construction rather than during training.
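To make the reported failure mode concrete, here is a small sketch of the kind of guard the comment describes. It is an illustration under stated assumptions, not the actual KTOTrainer source: the function name `check_kl_batch_size`, the `calculate_kl` flag, and the error message are hypothetical, and the real check may compare a different batch-size quantity.

```python
def check_kl_batch_size(per_device_train_batch_size, calculate_kl=True):
    """Hypothetical stand-in for an init-time guard on the KL term.

    Raises ValueError when the KL term is enabled but the batch holds at most
    one example, which is the situation the review comment reports.
    """
    if calculate_kl and per_device_train_batch_size <= 1:
        raise ValueError(
            "Batch size must be greater than 1 when the KL term is enabled."
        )


# The configuration flagged in the review: batch size 1 with KL enabled.
try:
    check_kl_batch_size(per_device_train_batch_size=1)
    raised = False
except ValueError:
    raised = True

# A batch size of 2 passes the same guard.
check_kl_batch_size(per_device_train_batch_size=2)
```

If the real trainer behaves as the comment says, the two obvious ways out are raising the per-device batch size above one or choosing a loss variant that does not compute the KL term. Which of these fits the test is a decision for the PR author.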


