Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 50 additions & 17 deletions trl/experimental/kto/kto_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,24 @@ def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]:

class KTOTrainer(_BaseTrainer):
"""
Initialize KTOTrainer.
Trainer for Kahneman-Tversky Optimization (KTO) method. This algorithm was initially proposed in the paper [KTO:
Model Alignment as Prospect Theoretic Optimization](https://huggingface.co/papers/2402.01306). This class is a
wrapper around the [`~transformers.Trainer`] class and inherits all of its attributes and methods.

Example:

```python
>>> from trl.experimental.kto import KTOTrainer
>>> from datasets import load_dataset

>>> dataset = load_dataset("trl-lib/kto-mix-14k", split="train")

>>> trainer = KTOTrainer(
... model="Qwen/Qwen2.5-0.5B-Instruct",
... train_dataset=dataset,
... )
>>> trainer.train()
```

Args:
model (`str` or [`~transformers.PreTrainedModel`] or [`~peft.PeftModel`]):
Expand All @@ -400,28 +417,44 @@ class KTOTrainer(_BaseTrainer):
state before KTO training starts.
args ([`experimental.kto.KTOConfig`], *optional*):
Configuration for this trainer. If `None`, a default configuration is used.
data_collator ([`~transformers.DataCollator`], *optional*):
Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`.
Will default to [`~experimental.kto.kto_trainer.DataCollatorForUnpairedPreference`] if the model is a
language model and [`~experimental.kto.kto_trainer.DataCollatorForVisionUnpairedPreference`] if the model
is a vision-language model. Custom collators must truncate sequences before padding; the trainer does not
apply post-collation truncation.
train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
The dataset to use for training.
Dataset to use for training. This trainer supports [unpaired preference](#unpaired-preference) type. The
format of the samples can be either:

- [Standard](dataset_formats#standard): Each sample contains plain text.
- [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
and content).
eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Dataset | IterableDataset]`):
The dataset to use for evaluation.
Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
processing_class ([`~transformers.PreTrainedTokenizerBase`] or [`~transformers.ProcessorMixin`], *optional*):
Processing class used to process the data. The padding side must be set to "left". If `None`, the
processing class is loaded from the model's name with [`~transformers.AutoProcessor.from_pretrained`]. A
padding token, `tokenizer.pad_token`, must be set. If the processing class has not set a padding token,
`tokenizer.eos_token` will be used as the default.
data_collator ([`~transformers.DataCollator`], *optional*):
The data collator to use for training. If None is specified, the default data collator
([`~experimental.kto.kto_trainer.DataCollatorForUnpairedPreference`]) will be used which will pad the
sequences to the maximum length of the sequences in the batch.
callbacks (`list[transformers.TrainerCallback]`):
The callbacks to use for training.
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
The optimizer and scheduler to use for training.
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
The function that will be used to compute metrics at evaluation. Must take a
[`~transformers.EvalPrediction`] and return a dictionary string to metric values. When passing
[`SFTConfig`] with `batch_eval_metrics` set to `True`, your `compute_metrics` function must take a boolean
`compute_result` argument. This will be triggered after the last eval batch to signal that the function
needs to calculate and return the global summary statistics rather than accumulating the batch-level
statistics.
callbacks (list of [`~transformers.TrainerCallback`], *optional*):
List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed
in [here](https://huggingface.co/docs/transformers/main_classes/callback).

If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
method.
optimizers (`tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None]`, *optional*, defaults to `(None, None)`):
A tuple containing the optimizer and the scheduler to use. Will default to an instance of `AdamW` on your
model and a scheduler given by [`~transformers.get_linear_schedule_with_warmup`] controlled by `args`.
peft_config ([`~peft.PeftConfig`], *optional*):
PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to
metric values.
"""

_tag_names = ["trl", "kto"]
Expand All @@ -444,14 +477,14 @@ def __init__(
model: "str | PreTrainedModel | PeftModel",
ref_model: PreTrainedModel | None = None,
args: KTOConfig | None = None,
data_collator: DataCollator | None = None,

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve positional KTOTrainer argument order

This moves data_collator ahead of train_dataset without making the following arguments keyword-only. Any existing script using the previous positional signature, for example KTOTrainer(model, ref_model, args, train_dataset), now binds the dataset to data_collator and leaves train_dataset as None, so initialization raises ValueError("train_dataset is required"); the public trl.KTOTrainer wrapper also forwards positional args directly. Please keep the old positional slots or add a compatibility shim before reordering.

Useful? React with 👍 / 👎.

train_dataset: Dataset | IterableDataset | None = None,
eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None,
processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None,
data_collator: DataCollator | None = None,
compute_metrics: Callable[[EvalLoopOutput], dict] | None = None,
callbacks: list[TrainerCallback] | None = None,
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None),
peft_config: "PeftConfig | None" = None,
compute_metrics: Callable[[EvalLoopOutput], dict] | None = None,
):
# Args
if args is None:
Expand Down
Loading