Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 40 additions & 40 deletions tests/experimental/test_kto_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class TestDataCollatorForVisionUnpairedPreference(TrlTestCase):
)
def test_mm_token_type_ids_shape(self):
# Regression guard: when the processor returns mm_token_type_ids (Qwen2.5-VL after transformers#43972),
# the collator must produce a KL_completion_token_type_ids whose width matches KL_completion_input_ids,
# the collator must produce a KL_token_type_ids whose width matches KL_input_ids,
# not the main completion's width (the two differ whenever their text lengths differ).
from PIL import Image
from transformers import AutoProcessor
Expand All @@ -68,14 +68,14 @@ def test_mm_token_type_ids_shape(self):
output = collator(examples)

assert "mm_token_type_ids" in output
assert output["mm_token_type_ids"].shape == output["completion_input_ids"].shape, (
assert output["mm_token_type_ids"].shape == output["input_ids"].shape, (
f"mm_token_type_ids shape {output['mm_token_type_ids'].shape} != "
f"completion_input_ids shape {output['completion_input_ids'].shape}"
f"input_ids shape {output['input_ids'].shape}"
)
assert "KL_completion_mm_token_type_ids" in output
assert output["KL_completion_mm_token_type_ids"].shape == output["KL_completion_input_ids"].shape, (
f"KL_completion_mm_token_type_ids shape {output['KL_completion_mm_token_type_ids'].shape} != "
f"KL_completion_input_ids shape {output['KL_completion_input_ids'].shape}"
assert "KL_mm_token_type_ids" in output
assert output["KL_mm_token_type_ids"].shape == output["KL_input_ids"].shape, (
f"KL_mm_token_type_ids shape {output['KL_mm_token_type_ids'].shape} != "
f"KL_input_ids shape {output['KL_input_ids'].shape}"
)

def test_output_keys(self):
Expand Down Expand Up @@ -104,16 +104,16 @@ def make_examples():
# With KL
collator = DataCollatorForVisionUnpairedPreference(processor, calculate_kl=True)
output = collator(make_examples())
for key in ["completion_input_ids", "completion_attention_mask", "completion_mask", "pixel_values", "label"]:
for key in ["input_ids", "attention_mask", "completion_mask", "pixel_values", "label"]:
assert key in output, f"Missing key: {key}"
for key in ["KL_completion_input_ids", "KL_completion_attention_mask", "KL_completion_mask"]:
for key in ["KL_input_ids", "KL_attention_mask", "KL_completion_mask"]:
assert key in output, f"Missing KL key: {key}"

# Without KL
collator_no_kl = DataCollatorForVisionUnpairedPreference(processor, calculate_kl=False)
output_no_kl = collator_no_kl(make_examples())
assert "completion_input_ids" in output_no_kl
assert "KL_completion_input_ids" not in output_no_kl
assert "input_ids" in output_no_kl
assert "KL_input_ids" not in output_no_kl

def test_kl_cycling(self):
# The KL completion for example i must be the completion from example i-1 (cycled by +1).
Expand Down Expand Up @@ -141,8 +141,8 @@ def test_kl_cycling(self):
output = collator(examples)
# KL completions are cycled: KL[0] = completion[-1], KL[1] = completion[0]
# They must differ from the matching main completion (unless both are identical strings, which they aren't here)
assert not torch.equal(output["completion_input_ids"][0], output["KL_completion_input_ids"][0])
assert not torch.equal(output["completion_input_ids"][1], output["KL_completion_input_ids"][1])
assert not torch.equal(output["input_ids"][0], output["KL_input_ids"][0])
assert not torch.equal(output["input_ids"][1], output["KL_input_ids"][1])


class TestDataCollatorForUnpairedPreference(TrlTestCase):
Expand All @@ -154,13 +154,13 @@ def test_padding_and_masks(self):
]
result = collator(examples)

expected_completion_input_ids = torch.tensor(
expected_input_ids = torch.tensor(
[
[1, 2, 3, 4, 5], # prompt + completion (example 1)
[7, 8, 9, 10, 0], # prompt + completion (example 2, padded)
]
)
expected_completion_attention_mask = torch.tensor(
expected_attention_mask = torch.tensor(
[
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 0],
Expand All @@ -172,13 +172,13 @@ def test_padding_and_masks(self):
[0, 0, 1, 1, 0], # completion (example 2, padded)
]
)
expected_kl_completion_input_ids = torch.tensor(
expected_kl_input_ids = torch.tensor(
[
[1, 2, 3, 6, 0], # prompt + KL completion (example 1, padded)
[7, 8, 11, 12, 13], # prompt + KL completion (example 2)
]
)
expected_kl_completion_attention_mask = torch.tensor(
expected_kl_attention_mask = torch.tensor(
[
[1, 1, 1, 1, 0],
[1, 1, 1, 1, 1],
Expand All @@ -192,19 +192,19 @@ def test_padding_and_masks(self):
)

assert set(result.keys()) == {
"completion_input_ids",
"completion_attention_mask",
"input_ids",
"attention_mask",
"completion_mask",
"KL_completion_input_ids",
"KL_completion_attention_mask",
"KL_input_ids",
"KL_attention_mask",
"KL_completion_mask",
"label",
}
torch.testing.assert_close(result["completion_input_ids"], expected_completion_input_ids)
torch.testing.assert_close(result["completion_attention_mask"], expected_completion_attention_mask)
torch.testing.assert_close(result["input_ids"], expected_input_ids)
torch.testing.assert_close(result["attention_mask"], expected_attention_mask)
torch.testing.assert_close(result["completion_mask"], expected_completion_mask)
torch.testing.assert_close(result["KL_completion_input_ids"], expected_kl_completion_input_ids)
torch.testing.assert_close(result["KL_completion_attention_mask"], expected_kl_completion_attention_mask)
torch.testing.assert_close(result["KL_input_ids"], expected_kl_input_ids)
torch.testing.assert_close(result["KL_attention_mask"], expected_kl_attention_mask)
torch.testing.assert_close(result["KL_completion_mask"], expected_kl_completion_mask)
assert result["label"] == [True, False]

Expand Down Expand Up @@ -234,11 +234,11 @@ def test_optional_reference_logps(self):
expected_ref_kl_logps = torch.tensor([0.2, 0.4])

assert set(result.keys()) == {
"completion_input_ids",
"completion_attention_mask",
"input_ids",
"attention_mask",
"completion_mask",
"KL_completion_input_ids",
"KL_completion_attention_mask",
"KL_input_ids",
"KL_attention_mask",
"KL_completion_mask",
"ref_logps",
"ref_KL_logps",
Expand All @@ -255,30 +255,30 @@ def test_with_pad_to_multiple_of(self):
]
result = collator(examples)

expected_completion_input_ids = torch.tensor(
expected_input_ids = torch.tensor(
[
[1, 2, 0, 0, 0], # prompt + completion (example 1, padded to multiple of 5)
[4, 5, 6, 7, 0], # prompt + completion (example 2)
]
)
expected_kl_completion_input_ids = torch.tensor(
expected_kl_input_ids = torch.tensor(
[
[1, 3, 0, 0, 0], # prompt + KL completion (example 1, padded to multiple of 5)
[4, 5, 8, 9, 0], # prompt + KL completion (example 2)
]
)

assert set(result.keys()) == {
"completion_input_ids",
"completion_attention_mask",
"input_ids",
"attention_mask",
"completion_mask",
"KL_completion_input_ids",
"KL_completion_attention_mask",
"KL_input_ids",
"KL_attention_mask",
"KL_completion_mask",
"label",
}
torch.testing.assert_close(result["completion_input_ids"], expected_completion_input_ids)
torch.testing.assert_close(result["KL_completion_input_ids"], expected_kl_completion_input_ids)
torch.testing.assert_close(result["input_ids"], expected_input_ids)
torch.testing.assert_close(result["KL_input_ids"], expected_kl_input_ids)


class TestKTOTrainer(TrlTestCase):
Expand Down Expand Up @@ -452,8 +452,8 @@ def test_tokenize_and_process_tokens(self):
# Verify the collator output (assembly, BOS/EOS insertion, labels).
example = trainer.train_dataset[0]
batch = trainer.data_collator([example])
# completion_input_ids ends with EOS
assert batch["completion_input_ids"][0, -1].item() == self.tokenizer.eos_token_id
# input_ids ends with EOS
assert batch["input_ids"][0, -1].item() == self.tokenizer.eos_token_id
# completion_mask: prompt tokens are 0, completion tokens are 1; at least the prompt is masked
assert "completion_mask" in batch
completion_mask = batch["completion_mask"][0].tolist()
Expand Down Expand Up @@ -1205,7 +1205,7 @@ def test_train_vlm_text_only_data(self, model_id, dataset_config):
assert not torch.equal(param, new_param), f"Param {n} is not updated"

def test_train_vlm_with_max_length(self):
# Regression test: mm_token_type_ids (and KL_completion_mm_token_type_ids) must be truncated alongside
# Regression test: mm_token_type_ids (and KL_mm_token_type_ids) must be truncated alongside
# input_ids when max_length is set, otherwise a shape mismatch crashes the model forward pass.
# max_length=37 truncates 1 completion token (total_len=38) while keeping all image tokens (prompt_len=34) safe.
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_unpaired_preference", split="train")
Expand Down
Loading
Loading