diff --git a/trl/experimental/kto/kto_trainer.py b/trl/experimental/kto/kto_trainer.py index 8b38a8ccb34..98824a27724 100644 --- a/trl/experimental/kto/kto_trainer.py +++ b/trl/experimental/kto/kto_trainer.py @@ -377,7 +377,24 @@ def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]: class KTOTrainer(_BaseTrainer): """ - Initialize KTOTrainer. + Trainer for Kahneman-Tversky Optimization (KTO) method. This algorithm was initially proposed in the paper [KTO: + Model Alignment as Prospect Theoretic Optimization](https://huggingface.co/papers/2402.01306). This class is a + wrapper around the [`~transformers.Trainer`] class and inherits all of its attributes and methods. + + Example: + + ```python + >>> from trl.experimental.kto import KTOTrainer + >>> from datasets import load_dataset + + >>> dataset = load_dataset("trl-lib/kto-mix-14k", split="train") + + >>> trainer = KTOTrainer( + ... model="Qwen/Qwen2.5-0.5B-Instruct", + ... train_dataset=dataset, + ... ) + >>> trainer.train() + ``` Args: model (`str` or [`~transformers.PreTrainedModel`] or [`~peft.PeftModel`]): @@ -398,28 +415,44 @@ class KTOTrainer(_BaseTrainer): state before KTO training starts. args ([`experimental.kto.KTOConfig`], *optional*): Configuration for this trainer. If `None`, a default configuration is used. + data_collator ([`~transformers.DataCollator`], *optional*): + Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`. + Will default to [`~experimental.kto.kto_trainer.DataCollatorForUnpairedPreference`] if the model is a + language model and [`~experimental.kto.kto_trainer.DataCollatorForVisionUnpairedPreference`] if the model + is a vision-language model. Custom collators must truncate sequences before padding; the trainer does not + apply post-collation truncation. train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): - The dataset to use for training. + Dataset to use for training. This trainer supports [unpaired preference](#unpaired-preference) type. The + format of the samples can be either: + + - [Standard](dataset_formats#standard): Each sample contains plain text. + - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role + and content). eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Dataset | IterableDataset]`): - The dataset to use for evaluation. + Dataset to use for evaluation. It must meet the same requirements as `train_dataset`. processing_class ([`~transformers.PreTrainedTokenizerBase`] or [`~transformers.ProcessorMixin`], *optional*): Processing class used to process the data. The padding side must be set to "left". If `None`, the processing class is loaded from the model's name with [`~transformers.AutoProcessor.from_pretrained`]. A padding token, `tokenizer.pad_token`, must be set. If the processing class has not set a padding token, `tokenizer.eos_token` will be used as the default. - data_collator ([`~transformers.DataCollator`], *optional*): - The data collator to use for training. If None is specified, the default data collator - ([`~experimental.kto.kto_trainer.DataCollatorForUnpairedPreference`]) will be used which will pad the - sequences to the maximum length of the sequences in the batch. - callbacks (`list[transformers.TrainerCallback]`): - The callbacks to use for training. - optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): - The optimizer and scheduler to use for training. + compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): + The function that will be used to compute metrics at evaluation. Must take a + [`~transformers.EvalPrediction`] and return a dictionary string to metric values. When passing + [`SFTConfig`] with `batch_eval_metrics` set to `True`, your `compute_metrics` function must take a boolean + `compute_result` argument. This will be triggered after the last eval batch to signal that the function + needs to calculate and return the global summary statistics rather than accumulating the batch-level + statistics. + callbacks (list of [`~transformers.TrainerCallback`], *optional*): + List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed + in [here](https://huggingface.co/docs/transformers/main_classes/callback). + + If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`] + method. + optimizers (`tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None]`, *optional*, defaults to `(None, None)`): + A tuple containing the optimizer and the scheduler to use. Will default to an instance of `AdamW` on your + model and a scheduler given by [`~transformers.get_linear_schedule_with_warmup`] controlled by `args`. peft_config ([`~peft.PeftConfig`], *optional*): PEFT configuration used to wrap the model. If `None`, the model is not wrapped. - compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): - The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to - metric values. """ _tag_names = ["trl", "kto"] @@ -442,14 +475,14 @@ def __init__( model: "str | PreTrainedModel | PeftModel", ref_model: PreTrainedModel | None = None, args: KTOConfig | None = None, + data_collator: DataCollator | None = None, train_dataset: Dataset | IterableDataset | None = None, eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None, processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None, - data_collator: DataCollator | None = None, + compute_metrics: Callable[[EvalLoopOutput], dict] | None = None, callbacks: list[TrainerCallback] | None = None, - optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None), peft_config: "PeftConfig | None" = None, - compute_metrics: Callable[[EvalLoopOutput], dict] | None = None, ): # Args if args is None: