Skip to content
141 changes: 32 additions & 109 deletions tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,71 +238,6 @@ def test_different_pad_token_id(self):
torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]]))

def test_max_length_keep_start(self):
    """With the default `keep_start` mode, over-long sequences keep only their first `max_length` tokens."""
    batch = [
        {"input_ids": [1, 2, 3, 4, 5], "labels": [1, 2, 3, 4, 5]},
        {"input_ids": [6, 7, 8], "labels": [6, 7, 8]},
    ]
    collate = DataCollatorForLanguageModeling(pad_token_id=0, max_length=3)

    output = collate(batch)

    # The first example loses its tail (4, 5); the second already fits and is left untouched.
    kept = torch.tensor([[1, 2, 3], [6, 7, 8]])
    assert set(output.keys()) == {"input_ids", "attention_mask", "labels"}
    torch.testing.assert_close(output["input_ids"], kept)
    # Both rows are exactly `max_length` long, so no position is padding.
    torch.testing.assert_close(output["attention_mask"], torch.ones_like(kept))
    torch.testing.assert_close(output["labels"], kept)

def test_max_length_keep_end(self):
    """With `truncation_mode="keep_end"`, over-long sequences keep only their last `max_length` tokens."""
    batch = [
        {"input_ids": [1, 2, 3, 4, 5], "labels": [1, 2, 3, 4, 5]},
        {"input_ids": [6, 7, 8], "labels": [6, 7, 8]},
    ]
    collate = DataCollatorForLanguageModeling(pad_token_id=0, max_length=3, truncation_mode="keep_end")

    output = collate(batch)

    # The first example loses its head (1, 2); the second already fits and is left untouched.
    kept = torch.tensor([[3, 4, 5], [6, 7, 8]])
    assert set(output.keys()) == {"input_ids", "attention_mask", "labels"}
    torch.testing.assert_close(output["input_ids"], kept)
    # Both rows are exactly `max_length` long, so no position is padding.
    torch.testing.assert_close(output["attention_mask"], torch.ones_like(kept))
    torch.testing.assert_close(output["labels"], kept)

def test_max_length_no_truncation_needed(self):
    """A `max_length` above every sequence length must leave the batch as plain padding would produce it."""
    batch = [
        {"input_ids": [1, 2, 3], "labels": [1, 2, 3]},
        {"input_ids": [4, 5], "labels": [4, 5]},
    ]
    collate = DataCollatorForLanguageModeling(pad_token_id=0, max_length=10)

    output = collate(batch)

    assert set(output.keys()) == {"input_ids", "attention_mask", "labels"}
    # The shorter row is padded to the batch maximum (3), not to `max_length` (10):
    # pad token 0 in the inputs, 0 in the mask, and -100 in the labels.
    expected_by_key = {
        "input_ids": torch.tensor([[1, 2, 3], [4, 5, 0]]),
        "attention_mask": torch.tensor([[1, 1, 1], [1, 1, 0]]),
        "labels": torch.tensor([[1, 2, 3], [4, 5, -100]]),
    }
    for key, expected in expected_by_key.items():
        torch.testing.assert_close(output[key], expected)

def test_max_length_without_labels(self):
    """Examples without a `labels` key: labels fall back to the input IDs and get the same truncation window."""
    batch = [{"input_ids": [1, 2, 3, 4, 5]}, {"input_ids": [6, 7, 8]}]
    collate = DataCollatorForLanguageModeling(pad_token_id=0, max_length=3)

    output = collate(batch)

    # Labels are expected to mirror the truncated input IDs exactly.
    kept = torch.tensor([[1, 2, 3], [6, 7, 8]])
    assert set(output.keys()) == {"input_ids", "attention_mask", "labels"}
    torch.testing.assert_close(output["input_ids"], kept)
    torch.testing.assert_close(output["attention_mask"], torch.ones_like(kept))
    torch.testing.assert_close(output["labels"], kept)

def test_max_length_invalid_truncation_mode(self):
    """An unknown `truncation_mode` must be rejected with a ValueError when the collator is called."""
    batch = [
        {"input_ids": [1, 2, 3, 4, 5], "labels": [1, 2, 3, 4, 5]},
        {"input_ids": [6, 7, 8], "labels": [6, 7, 8]},
    ]
    # Construction sits outside the `raises` block: only the call itself is expected to fail.
    collate = DataCollatorForLanguageModeling(pad_token_id=0, max_length=3, truncation_mode="invalid")

    with pytest.raises(ValueError, match="Unsupported truncation mode"):
        collate(batch)

def test_single_example_single_doc(self):
batch_seq_lengths = [[5]]
result = DataCollatorForLanguageModeling.get_position_ids_from_packed_seq_lengths(batch_seq_lengths)
Expand Down Expand Up @@ -1052,25 +987,6 @@ def tokenize_example(example):
new_param = trainer.model.get_parameter(n)
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

def test_skip_prepare_dataset_passes_truncation_to_text_collator(self):
    """With `skip_prepare_dataset=True`, the default text collator must carry the config's `max_length` and
    `truncation_mode`."""
    train_split = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train[:2]")

    config_kwargs = {
        "output_dir": self.tmp_dir,
        "max_length": 16,
        "truncation_mode": "keep_end",
        "dataset_kwargs": {"skip_prepare_dataset": True},
        "report_to": "none",
    }
    # Only the config construction is expected to emit the `keep_end` deprecation warning.
    with pytest.warns(FutureWarning, match="keep_end.*deprecated"):
        training_args = SFTConfig(**config_kwargs)

    trainer = SFTTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        args=training_args,
        train_dataset=train_split,
    )

    collator = trainer.data_collator
    assert isinstance(collator, DataCollatorForLanguageModeling)
    assert collator.max_length == 16
    assert collator.truncation_mode == "keep_end"

def test_dataset_with_transform_requires_skip_prepare_dataset(self):
dataset = Dataset.from_dict({"text": ["hello world"]})

Expand Down Expand Up @@ -1304,7 +1220,7 @@ def test_train_assistant_only(self):
new_param = trainer.model.get_parameter(n)
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

def test_dataset_prep_builds_labels_for_assistant_only_loss(self):
def test_dataset_preparation_builds_labels_for_assistant_only_loss(self):
"""Dataset preparation must bake the assistant masks into a labels column."""
dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train")

Expand All @@ -1315,18 +1231,16 @@ def test_dataset_prep_builds_labels_for_assistant_only_loss(self):

assert "labels" in trainer.train_dataset.column_names
for example in trainer.train_dataset:
assert len(example["labels"]) == len(example["input_ids"])
expected = [
token_id if mask else -100
for token_id, mask in zip(example["input_ids"], example["assistant_masks"], strict=True)
]
assert example["labels"] == expected
labels, input_ids = example["labels"], example["input_ids"]
assert len(labels) == len(input_ids)
# Labels are input_ids with non-assistant tokens masked to -100.
assert all(label == -100 or label == token_id for label, token_id in zip(labels, input_ids, strict=True))
assert any(label != -100 for label in labels) # assistant tokens contribute to the loss
assert any(label == -100 for label in labels) # non-assistant tokens are masked

def test_labels_all_masked_after_truncation(self):
"""Regression test for #3927: when the assistant response lies beyond `max_length`, dataset preparation
builds labels that still hold real token IDs, but the slice surviving the collator's truncation is all -100
(the prompt). The bug was masking happening after truncation; building labels before truncation makes this
surfaceable."""
"""Regression test for #3927. When the assistant turn lies entirely beyond `max_length`, truncation keeps only
prompt tokens, which are all -100."""
dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train")

# `max_length` is small enough that the kept prefix is entirely prompt tokens (the assistant turn comes later).
Expand All @@ -1335,15 +1249,23 @@ def test_labels_all_masked_after_truncation(self):
model="trl-internal-testing/tiny-Qwen3ForCausalLM", args=training_args, train_dataset=dataset
)

# Before truncation, the prepared labels contain real (non -100) assistant token IDs.
labels = trainer.train_dataset[0]["labels"]
assert any(token_id != -100 for token_id in labels)
assert all(token_id == -100 for token_id in labels)

def test_dataset_truncated_to_max_length(self):
"""Dataset preparation truncates every example to `max_length`."""
dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

training_args = SFTConfig(output_dir=self.tmp_dir, max_length=4, report_to="none")
trainer = SFTTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset
)

# After the collator truncates to `max_length` (keep_start), the surviving labels are all -100.
batch = trainer.data_collator([trainer.train_dataset[0]])
assert batch["labels"].eq(-100).all()
for example in trainer.train_dataset:
assert len(example["input_ids"]) <= 4
assert len(example["labels"]) <= 4

def test_dataset_prep_builds_labels_for_completion_only(self):
def test_dataset_preparation_builds_labels_for_completion_only(self):
"""Dataset preparation must bake the completion mask into a labels column when completion_only_loss
resolves to True (the default for prompt-completion datasets)."""
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_completion", split="train")
Expand All @@ -1355,13 +1277,14 @@ def test_dataset_prep_builds_labels_for_completion_only(self):

assert "labels" in trainer.train_dataset.column_names
for example in trainer.train_dataset:
expected = [
token_id if mask else -100
for token_id, mask in zip(example["input_ids"], example["completion_mask"], strict=True)
]
assert example["labels"] == expected

def test_dataset_prep_respects_existing_labels(self):
labels, input_ids = example["labels"], example["input_ids"]
assert len(labels) == len(input_ids)
# Labels are input_ids with prompt tokens masked to -100.
assert all(label == -100 or label == token_id for label, token_id in zip(labels, input_ids, strict=True))
assert any(label != -100 for label in labels) # completion tokens contribute to the loss
assert any(label == -100 for label in labels) # prompt tokens are masked

def test_dataset_preparation_respects_existing_labels(self):
"""A user-provided labels column must be taken as is, even when mask columns are also present."""
dataset = Dataset.from_list(
[
Expand All @@ -1377,7 +1300,7 @@ def test_dataset_prep_respects_existing_labels(self):

assert trainer.train_dataset[:]["labels"] == [[1, -100, 3, -100], [-100, 6]]

def test_dataset_prep_builds_labels_for_pretokenized_with_masks(self):
def test_dataset_preparation_builds_labels_for_pretokenized_with_masks(self):
"""Pre-tokenized datasets that carry mask columns but no labels must get labels built at preparation."""
dataset = Dataset.from_list(
[
Expand Down
67 changes: 36 additions & 31 deletions trl/trainer/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,10 +399,10 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
Data collator used for language modeling data. Inputs are dynamically padded to the maximum length of a batch.

This collator expects each example in the input list to be a dictionary containing at least the `"input_ids"` key.
If the input contains `"labels"`, they are used as is (truncated and padded like the input IDs); otherwise the
labels default to the input IDs. Tokens that shouldn't contribute to the loss are expected to be already set to
`-100` in the labels; the [`SFTTrainer`] takes care of this during dataset preparation. The collator returns a
dictionary containing the following keys:
If the input contains `"labels"`, they are used as is (padded like the input IDs); otherwise the labels default to
the input IDs. Tokens that shouldn't contribute to the loss are expected to be already set to `-100` in the labels;
the [`SFTTrainer`] takes care of this during dataset preparation. The collator returns a dictionary containing the
following keys:
- `"input_ids"`: Tensor of input IDs, padded to the maximum length of the batch.
- `"labels"`: Tensor of labels, padded with `-100` to the maximum length of the batch. If `padding_free` is set
to `False`, the following key is also returned:
Expand All @@ -413,12 +413,6 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
Args:
pad_token_id (`int`):
Token ID to use for padding.
max_length (`int`, *optional*):
Maximum length of the sequences in the batch. Sequences longer than `max_length` are truncated to
`max_length`.
truncation_mode (`str`, *optional*, defaults to `"keep_start"`):
Truncation mode to use when the sequence exceeds `max_length`. Possible values are `"keep_end"` and
`"keep_start"`.
padding_free (`bool`, *optional*, defaults to `False`):
If set to `True`, the sequences will be flattened into a single sequence, and the position IDs will be
generated accordingly and returned instead of the attention mask.
Expand Down Expand Up @@ -464,8 +458,6 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
"""

pad_token_id: int
max_length: int | None = None
truncation_mode: str = "keep_start"
padding_free: bool = False
pad_to_multiple_of: int | None = None
return_tensors: str = "pt"
Expand All @@ -475,19 +467,6 @@ def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
batch_seq_lengths = [example["seq_lengths"] for example in examples] if "seq_lengths" in examples[0] else None
labels = [example.get("labels", example["input_ids"]) for example in examples]

# Truncate per sequence if necessary
if self.max_length is not None and not self.padding_free:
if self.truncation_mode == "keep_start":
sl = slice(None, self.max_length)
elif self.truncation_mode == "keep_end":
sl = slice(-self.max_length, None)
else:
raise ValueError(
f"Unsupported truncation mode: {self.truncation_mode}, expected 'keep_start' or 'keep_end'"
)
input_ids = [ids[sl] for ids in input_ids]
labels = [lbl[sl] for lbl in labels]

# Convert to tensor
input_ids = [torch.tensor(ids) for ids in input_ids]
labels = [torch.tensor(lbl) for lbl in labels]
Expand Down Expand Up @@ -851,8 +830,6 @@ class SFTTrainer(_BaseTrainer):
Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`.
Will default to [`~trainer.sft_trainer.DataCollatorForLanguageModeling`] if the model is a language model
and [`~trainer.sft_trainer.DataCollatorForVisionLanguageModeling`] if the model is a vision-language model.
Custom collators must truncate sequences before padding; the trainer does not apply post-collation
truncation.
train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
Dataset to use for training. This trainer supports both [language modeling](#language-modeling) type and
[prompt-completion](#prompt-completion) type. The format of the samples can be either:
Expand All @@ -861,7 +838,12 @@ class SFTTrainer(_BaseTrainer):
- [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
and content).

The trainer also supports processed datasets (tokenized) as long as they contain an `input_ids` field.
The trainer also supports pre-tokenized datasets, recognized by a required `input_ids` column. An optional
`labels` column (`-100` on tokens excluded from the loss) is used as is if present; otherwise labels are
built from the optional `assistant_masks` / `completion_mask` columns (which are folded in then dropped,
`completion_mask` only when `completion_only_loss=True`), or default to a copy of `input_ids`. Sequences
are truncated to `max_length` during preparation. With `skip_prepare_dataset=True`, preparation is skipped
and the collator is expected to handle the dataset as is.
eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Dataset | IterableDataset]`):
Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.ProcessorMixin`], *optional*):
Expand Down Expand Up @@ -1190,8 +1172,6 @@ def __init__(
self._tokenizer.pad_token = pad_token
data_collator = DataCollatorForLanguageModeling(
pad_token_id=self._tokenizer.pad_token_id,
max_length=None if self.padding_free else args.max_length,
truncation_mode=args.truncation_mode,
padding_free=self.padding_free,
pad_to_multiple_of=args.pad_to_multiple_of,
)
Expand Down Expand Up @@ -1584,7 +1564,32 @@ def build_labels(example, mask_columns):
]
return {"labels": labels}

dataset = dataset.map(build_labels, fn_kwargs={"mask_columns": mask_columns}, **map_kwargs)
dataset = dataset.map(
build_labels,
fn_kwargs={"mask_columns": mask_columns},
remove_columns=mask_columns,
**map_kwargs,
)

# Truncate to max_length. Skipped when packing, since packing already chunks sequences to max_length.
# Done here, during preparation, so the result is cached. When preparation is skipped
# (`skip_prepare_dataset=True`), no truncation is applied and the dataset must already be truncated.
if args.max_length is not None and not packing:
if args.truncation_mode == "keep_start":
sl = slice(None, args.max_length)
elif args.truncation_mode == "keep_end":
sl = slice(-args.max_length, None)
else:
raise ValueError(
f"Unsupported truncation mode: {args.truncation_mode}, expected 'keep_start' or 'keep_end'"
)
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Truncating {dataset_name} dataset"

def truncate(example, sl):
return {"input_ids": example["input_ids"][sl], "labels": example["labels"][sl]}
Comment thread
qgallouedec marked this conversation as resolved.

dataset = dataset.map(truncate, fn_kwargs={"sl": sl}, **map_kwargs)

# Pack
if packing:
Expand Down
Loading