Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions tests/experimental/test_kto_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1268,3 +1268,37 @@ def test_train_vlm_liger(self):
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
assert not torch.equal(param, new_param), f"Param {n} is not updated"


@pytest.mark.slow
class TestKTOTrainerSlow(TrlTestCase):
    """Slow KTO trainer tests that run against a full-size checkpoint."""

    # Gemma 3n uses a timm encoder, making it difficult to create a smaller variant for testing.
    # To ensure coverage, we run tests on the full model but mark them as slow to exclude from default runs.
    @pytest.mark.skip(reason="Model google/gemma-3n-E2B-it is gated and requires HF token")
    @require_vision
    def test_train_vlm_gemma_3n(self):
        """Train KTO on Gemma 3n with an image dataset and check that the parameters are updated."""
        dataset = load_dataset("trl-internal-testing/zen-image", "conversational_unpaired_preference", split="train")

        training_args = KTOConfig(
            output_dir=self.tmp_dir,
            learning_rate=0.1,  # use higher lr because gradients are tiny and default lr can stall updates
            max_length=None,  # for VLMs, truncating can remove image tokens, leading to errors
            # VLM training is memory intensive, so keep the batch as small as possible. It cannot be 1: the default
            # KTO loss keeps the KL term enabled, and KTOTrainer raises at initialization when KL is on and the
            # actual (per-device) batch size is at most one.
            per_device_train_batch_size=2,
            model_init_kwargs={"dtype": "bfloat16"},
            report_to="none",
        )
        trainer = KTOTrainer(model="google/gemma-3n-E2B-it", args=training_args, train_dataset=dataset)

        # Snapshot the parameters before training so they can be compared afterwards
        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

        trainer.train()

        assert trainer.state.log_history[-1]["train_loss"] is not None

        # Check that the params have changed
        for n, param in previous_trainable_params.items():
            new_param = trainer.model.get_parameter(n)
            if "model.audio_tower" in n or "model.embed_audio" in n:
                # The audio embedding parameters are not updated because this dataset contains no audio data
                continue
            assert not torch.equal(param, new_param), f"Param {n} is not updated"
Loading