diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 7e52f700b49..517c032fe14 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -238,71 +238,6 @@ def test_different_pad_token_id(self): torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]])) torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]])) - def test_max_length_keep_start(self): - """Test that sequences longer than max_length are truncated from the start.""" - collator = DataCollatorForLanguageModeling(pad_token_id=0, max_length=3) - examples = [ - {"input_ids": [1, 2, 3, 4, 5], "labels": [1, 2, 3, 4, 5]}, - {"input_ids": [6, 7, 8], "labels": [6, 7, 8]}, - ] - - result = collator(examples) - - assert set(result.keys()) == {"input_ids", "attention_mask", "labels"} - torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [6, 7, 8]])) - torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 1]])) - torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [6, 7, 8]])) - - def test_max_length_keep_end(self): - """Test that sequences longer than max_length are truncated from the end (keeping last tokens).""" - collator = DataCollatorForLanguageModeling(pad_token_id=0, max_length=3, truncation_mode="keep_end") - examples = [ - {"input_ids": [1, 2, 3, 4, 5], "labels": [1, 2, 3, 4, 5]}, - {"input_ids": [6, 7, 8], "labels": [6, 7, 8]}, - ] - - result = collator(examples) - - assert set(result.keys()) == {"input_ids", "attention_mask", "labels"} - torch.testing.assert_close(result["input_ids"], torch.tensor([[3, 4, 5], [6, 7, 8]])) - torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 1]])) - torch.testing.assert_close(result["labels"], torch.tensor([[3, 4, 5], [6, 7, 8]])) - - def test_max_length_no_truncation_needed(self): - """Test that max_length larger than sequences does not alter the output.""" - collator = DataCollatorForLanguageModeling(pad_token_id=0, max_length=10) - examples = [{"input_ids": [1, 2, 3], "labels": [1, 2, 3]}, {"input_ids": [4, 5], "labels": [4, 5]}] - - result = collator(examples) - - assert set(result.keys()) == {"input_ids", "attention_mask", "labels"} - torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]])) - torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]])) - torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]])) - - def test_max_length_without_labels(self): - """Truncation without labels: labels default to the input IDs and are truncated with the same window.""" - collator = DataCollatorForLanguageModeling(pad_token_id=0, max_length=3) - examples = [{"input_ids": [1, 2, 3, 4, 5]}, {"input_ids": [6, 7, 8]}] - - result = collator(examples) - - assert set(result.keys()) == {"input_ids", "attention_mask", "labels"} - torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [6, 7, 8]])) - torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 1]])) - torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [6, 7, 8]])) - - def test_max_length_invalid_truncation_mode(self): - """Test that an invalid truncation_mode raises ValueError.""" - collator = DataCollatorForLanguageModeling(pad_token_id=0, max_length=3, truncation_mode="invalid") - examples = [ - {"input_ids": [1, 2, 3, 4, 5], "labels": [1, 2, 3, 4, 5]}, - {"input_ids": [6, 7, 8], "labels": [6, 7, 8]}, - ] - - with pytest.raises(ValueError, match="Unsupported truncation mode"): - collator(examples) - def test_single_example_single_doc(self): batch_seq_lengths = [[5]] result = DataCollatorForLanguageModeling.get_position_ids_from_packed_seq_lengths(batch_seq_lengths) @@ -1052,25 +987,6 @@ def tokenize_example(example): new_param = trainer.model.get_parameter(n) assert not torch.equal(param, new_param), f"Parameter {n} has not changed." - def test_skip_prepare_dataset_passes_truncation_to_text_collator(self): - dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train[:2]") - with pytest.warns(FutureWarning, match="keep_end.*deprecated"): - training_args = SFTConfig( - output_dir=self.tmp_dir, - max_length=16, - truncation_mode="keep_end", - dataset_kwargs={"skip_prepare_dataset": True}, - report_to="none", - ) - - trainer = SFTTrainer( - model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset - ) - - assert isinstance(trainer.data_collator, DataCollatorForLanguageModeling) - assert trainer.data_collator.max_length == 16 - assert trainer.data_collator.truncation_mode == "keep_end" - def test_dataset_with_transform_requires_skip_prepare_dataset(self): dataset = Dataset.from_dict({"text": ["hello world"]}) @@ -1304,7 +1220,7 @@ def test_train_assistant_only(self): new_param = trainer.model.get_parameter(n) assert not torch.equal(param, new_param), f"Parameter {n} has not changed." - def test_dataset_prep_builds_labels_for_assistant_only_loss(self): + def test_dataset_preparation_builds_labels_for_assistant_only_loss(self): """Dataset preparation must bake the assistant masks into a labels column.""" dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train") @@ -1315,18 +1231,16 @@ def test_dataset_prep_builds_labels_for_assistant_only_loss(self): assert "labels" in trainer.train_dataset.column_names for example in trainer.train_dataset: - assert len(example["labels"]) == len(example["input_ids"]) - expected = [ - token_id if mask else -100 - for token_id, mask in zip(example["input_ids"], example["assistant_masks"], strict=True) - ] - assert example["labels"] == expected + labels, input_ids = example["labels"], example["input_ids"] + assert len(labels) == len(input_ids) + # Labels are input_ids with non-assistant tokens masked to -100. + assert all(label == -100 or label == token_id for label, token_id in zip(labels, input_ids, strict=True)) + assert any(label != -100 for label in labels) # assistant tokens contribute to the loss + assert any(label == -100 for label in labels) # non-assistant tokens are masked def test_labels_all_masked_after_truncation(self): - """Regression test for #3927: when the assistant response lies beyond `max_length`, dataset preparation - builds labels that still hold real token IDs, but the slice surviving the collator's truncation is all -100 - (the prompt). The bug was masking happening after truncation; building labels before truncation makes this - surfaceable.""" + """Regression test for #3927. When the assistant turn lies entirely beyond `max_length`, truncation keeps only + prompt tokens, which are all -100.""" dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train") # `max_length` is small enough that the kept prefix is entirely prompt tokens (the assistant turn comes later). @@ -1335,15 +1249,23 @@ def test_labels_all_masked_after_truncation(self): model="trl-internal-testing/tiny-Qwen3ForCausalLM", args=training_args, train_dataset=dataset ) - # Before truncation, the prepared labels contain real (non -100) assistant token IDs. labels = trainer.train_dataset[0]["labels"] - assert any(token_id != -100 for token_id in labels) + assert all(token_id == -100 for token_id in labels) + + def test_dataset_truncated_to_max_length(self): + """Dataset preparation truncates every example to `max_length`.""" + dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") + + training_args = SFTConfig(output_dir=self.tmp_dir, max_length=4, report_to="none") + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset + ) - # After the collator truncates to `max_length` (keep_start), the surviving labels are all -100. - batch = trainer.data_collator([trainer.train_dataset[0]]) - assert batch["labels"].eq(-100).all() + for example in trainer.train_dataset: + assert len(example["input_ids"]) <= 4 + assert len(example["labels"]) <= 4 - def test_dataset_prep_builds_labels_for_completion_only(self): + def test_dataset_preparation_builds_labels_for_completion_only(self): """Dataset preparation must bake the completion mask into a labels column when completion_only_loss resolves to True (the default for prompt-completion datasets).""" dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_completion", split="train") @@ -1355,13 +1277,14 @@ def test_dataset_prep_builds_labels_for_completion_only(self): assert "labels" in trainer.train_dataset.column_names for example in trainer.train_dataset: - expected = [ - token_id if mask else -100 - for token_id, mask in zip(example["input_ids"], example["completion_mask"], strict=True) - ] - assert example["labels"] == expected - - def test_dataset_prep_respects_existing_labels(self): + labels, input_ids = example["labels"], example["input_ids"] + assert len(labels) == len(input_ids) + # Labels are input_ids with prompt tokens masked to -100. + assert all(label == -100 or label == token_id for label, token_id in zip(labels, input_ids, strict=True)) + assert any(label != -100 for label in labels) # completion tokens contribute to the loss + assert any(label == -100 for label in labels) # prompt tokens are masked + + def test_dataset_preparation_respects_existing_labels(self): """A user-provided labels column must be taken as is, even when mask columns are also present.""" dataset = Dataset.from_list( [ @@ -1377,7 +1300,7 @@ def test_dataset_prep_respects_existing_labels(self): assert trainer.train_dataset[:]["labels"] == [[1, -100, 3, -100], [-100, 6]] - def test_dataset_prep_builds_labels_for_pretokenized_with_masks(self): + def test_dataset_preparation_builds_labels_for_pretokenized_with_masks(self): """Pre-tokenized datasets that carry mask columns but no labels must get labels built at preparation.""" dataset = Dataset.from_list( [ diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 9f2c27344d9..79c7df9fedf 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -399,10 +399,10 @@ class DataCollatorForLanguageModeling(DataCollatorMixin): Data collator used for language modeling data. Inputs are dynamically padded to the maximum length of a batch. This collator expects each example in the input list to be a dictionary containing at least the `"input_ids"` key. - If the input contains `"labels"`, they are used as is (truncated and padded like the input IDs); otherwise the - labels default to the input IDs. Tokens that shouldn't contribute to the loss are expected to be already set to - `-100` in the labels; the [`SFTTrainer`] takes care of this during dataset preparation. The collator returns a - dictionary containing the following keys: + If the input contains `"labels"`, they are used as is (padded like the input IDs); otherwise the labels default to + the input IDs. Tokens that shouldn't contribute to the loss are expected to be already set to `-100` in the labels; + the [`SFTTrainer`] takes care of this during dataset preparation. The collator returns a dictionary containing the + following keys: - `"input_ids"`: Tensor of input IDs, padded to the maximum length of the batch. - `"labels"`: Tensor of labels, padded with `-100` to the maximum length of the batch. If `padding_free` is set to `False`, the following key is also returned: @@ -413,12 +413,6 @@ class DataCollatorForLanguageModeling(DataCollatorMixin): Args: pad_token_id (`int`): Token ID to use for padding. - max_length (`int`, *optional*): - Maximum length of the sequences in the batch. Sequences longer than `max_length` are truncated to - `max_length`. - truncation_mode (`str`, *optional*, defaults to `"keep_start"`): - Truncation mode to use when the sequence exceeds `max_length`. Possible values are `"keep_end"` and - `"keep_start"`. padding_free (`bool`, *optional*, defaults to `False`): If set to `True`, the sequences will be flattened into a single sequence, and the position IDs will be generated accordingly and returned instead of the attention mask. @@ -464,8 +458,6 @@ class DataCollatorForLanguageModeling(DataCollatorMixin): """ pad_token_id: int - max_length: int | None = None - truncation_mode: str = "keep_start" padding_free: bool = False pad_to_multiple_of: int | None = None return_tensors: str = "pt" @@ -475,19 +467,6 @@ def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]: batch_seq_lengths = [example["seq_lengths"] for example in examples] if "seq_lengths" in examples[0] else None labels = [example.get("labels", example["input_ids"]) for example in examples] - # Truncate per sequence if necessary - if self.max_length is not None and not self.padding_free: - if self.truncation_mode == "keep_start": - sl = slice(None, self.max_length) - elif self.truncation_mode == "keep_end": - sl = slice(-self.max_length, None) - else: - raise ValueError( - f"Unsupported truncation mode: {self.truncation_mode}, expected 'keep_start' or 'keep_end'" - ) - input_ids = [ids[sl] for ids in input_ids] - labels = [lbl[sl] for lbl in labels] - # Convert to tensor input_ids = [torch.tensor(ids) for ids in input_ids] labels = [torch.tensor(lbl) for lbl in labels] @@ -851,8 +830,6 @@ class SFTTrainer(_BaseTrainer): Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`. Will default to [`~trainer.sft_trainer.DataCollatorForLanguageModeling`] if the model is a language model and [`~trainer.sft_trainer.DataCollatorForVisionLanguageModeling`] if the model is a vision-language model. - Custom collators must truncate sequences before padding; the trainer does not apply post-collation - truncation. train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): Dataset to use for training. This trainer supports both [language modeling](#language-modeling) type and [prompt-completion](#prompt-completion) type. The format of the samples can be either: @@ -861,7 +838,12 @@ class SFTTrainer(_BaseTrainer): - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role and content). - The trainer also supports processed datasets (tokenized) as long as they contain an `input_ids` field. + The trainer also supports pre-tokenized datasets, recognized by a required `input_ids` column. An optional + `labels` column (`-100` on tokens excluded from the loss) is used as is if present; otherwise labels are + built from the optional `assistant_masks` / `completion_mask` columns (which are folded in then dropped, + `completion_mask` only when `completion_only_loss=True`), or default to a copy of `input_ids`. Sequences + are truncated to `max_length` during preparation. With `skip_prepare_dataset=True`, preparation is skipped + and the collator is expected to handle the dataset as is. eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Dataset | IterableDataset]`): Dataset to use for evaluation. It must meet the same requirements as `train_dataset`. processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.ProcessorMixin`], *optional*): @@ -1190,8 +1172,6 @@ def __init__( self._tokenizer.pad_token = pad_token data_collator = DataCollatorForLanguageModeling( pad_token_id=self._tokenizer.pad_token_id, - max_length=None if self.padding_free else args.max_length, - truncation_mode=args.truncation_mode, padding_free=self.padding_free, pad_to_multiple_of=args.pad_to_multiple_of, ) @@ -1584,7 +1564,32 @@ def build_labels(example, mask_columns): ] return {"labels": labels} - dataset = dataset.map(build_labels, fn_kwargs={"mask_columns": mask_columns}, **map_kwargs) + dataset = dataset.map( + build_labels, + fn_kwargs={"mask_columns": mask_columns}, + remove_columns=mask_columns, + **map_kwargs, + ) + + # Truncate to max_length. Skipped when packing, since packing already chunks sequences to max_length. + # Done here, during preparation, so the result is cached. When preparation is skipped + # (`skip_prepare_dataset=True`), no truncation is applied and the dataset must already be truncated. + if args.max_length is not None and not packing: + if args.truncation_mode == "keep_start": + sl = slice(None, args.max_length) + elif args.truncation_mode == "keep_end": + sl = slice(-args.max_length, None) + else: + raise ValueError( + f"Unsupported truncation mode: {args.truncation_mode}, expected 'keep_start' or 'keep_end'" + ) + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Truncating {dataset_name} dataset" + + def truncate(example, sl): + return {"input_ids": example["input_ids"][sl], "labels": example["labels"][sl]} + + dataset = dataset.map(truncate, fn_kwargs={"sl": sl}, **map_kwargs) # Pack if packing: