Skip to content
16 changes: 6 additions & 10 deletions docs/source/peft_integration.md
Original file line number Diff line number Diff line change
Expand Up @@ -449,11 +449,13 @@ python trl/scripts/sft.py \

#### Python Example

Pass the `quantization_config` directly to the trainer alongside `peft_config` — the trainer loads and quantizes the model for you. The same `quantization_config` argument is available on [`SFTTrainer`], [`DPOTrainer`], [`GRPOTrainer`], [`RLOOTrainer`], and [`RewardTrainer`].

```python
import torch

from peft import LoraConfig
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from transformers import BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer

# Configure 4-bit quantization
Expand All @@ -464,13 +466,6 @@ bnb_config = BitsAndBytesConfig(
bnb_4bit_use_double_quant=True,
)

# Load model with quantization
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
quantization_config=bnb_config,
device_map="auto",
)

# Configure LoRA
peft_config = LoraConfig(
r=32,
Expand All @@ -486,11 +481,12 @@ training_args = SFTConfig(
...
)

# Create trainer with PEFT config
# Create trainer with quantization and PEFT config
trainer = SFTTrainer(
model=model,
model="meta-llama/Llama-2-7b-hf",
args=training_args,
train_dataset=dataset,
quantization_config=bnb_config,
peft_config=peft_config,
)

Expand Down
8 changes: 2 additions & 6 deletions trl/scripts/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def main(script_args, training_args, model_args, dataset_args):
from accelerate.logging import get_logger
from datasets import load_dataset

from trl import DPOTrainer, get_dataset, get_kbit_device_map, get_peft_config, get_quantization_config
from trl import DPOTrainer, get_dataset, get_peft_config, get_quantization_config

logger = get_logger(__name__)

Expand All @@ -75,11 +75,6 @@ def main(script_args, training_args, model_args, dataset_args):
attn_implementation=model_args.attn_implementation,
dtype=model_args.dtype,
)
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
training_args.model_init_kwargs["device_map"] = get_kbit_device_map()
training_args.model_init_kwargs["quantization_config"] = quantization_config

# Load the dataset
if dataset_args.datasets and script_args.dataset_name:
Expand All @@ -103,6 +98,7 @@ def main(script_args, training_args, model_args, dataset_args):
args=training_args,
train_dataset=dataset[script_args.dataset_train_split],
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
quantization_config=get_quantization_config(model_args),
peft_config=get_peft_config(model_args),
)

Expand Down
8 changes: 2 additions & 6 deletions trl/scripts/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def main(script_args, training_args, model_args, dataset_args):
from accelerate.logging import get_logger
from datasets import load_dataset

from trl import GRPOTrainer, get_dataset, get_kbit_device_map, get_peft_config, get_quantization_config
from trl import GRPOTrainer, get_dataset, get_peft_config, get_quantization_config
from trl.rewards import (
accuracy_reward,
get_soft_overlong_punishment,
Expand Down Expand Up @@ -113,11 +113,6 @@ def main(script_args, training_args, model_args, dataset_args):
attn_implementation=model_args.attn_implementation,
dtype=model_args.dtype,
)
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
training_args.model_init_kwargs["device_map"] = get_kbit_device_map()
training_args.model_init_kwargs["quantization_config"] = quantization_config

# Load the dataset
if dataset_args.datasets and script_args.dataset_name:
Expand All @@ -142,6 +137,7 @@ def main(script_args, training_args, model_args, dataset_args):
args=training_args,
train_dataset=dataset[script_args.dataset_train_split],
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
quantization_config=get_quantization_config(model_args),
peft_config=get_peft_config(model_args),
)

Expand Down
8 changes: 2 additions & 6 deletions trl/scripts/reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def main(script_args, training_args, model_args, dataset_args):
from accelerate.logging import get_logger
from datasets import load_dataset

from trl import RewardTrainer, get_dataset, get_kbit_device_map, get_peft_config, get_quantization_config
from trl import RewardTrainer, get_dataset, get_peft_config, get_quantization_config

logger = get_logger(__name__)

Expand All @@ -38,11 +38,6 @@ def main(script_args, training_args, model_args, dataset_args):
attn_implementation=model_args.attn_implementation,
dtype=model_args.dtype,
)
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
training_args.model_init_kwargs["device_map"] = get_kbit_device_map()
training_args.model_init_kwargs["quantization_config"] = quantization_config

# Load the dataset
if dataset_args.datasets and script_args.dataset_name:
Expand All @@ -66,6 +61,7 @@ def main(script_args, training_args, model_args, dataset_args):
args=training_args,
train_dataset=dataset[script_args.dataset_train_split],
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
quantization_config=get_quantization_config(model_args),
peft_config=get_peft_config(model_args),
)

Expand Down
8 changes: 2 additions & 6 deletions trl/scripts/rloo.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def main(script_args, training_args, model_args, dataset_args):
from accelerate.logging import get_logger
from datasets import load_dataset

from trl import RLOOTrainer, get_dataset, get_kbit_device_map, get_peft_config, get_quantization_config
from trl import RLOOTrainer, get_dataset, get_peft_config, get_quantization_config
from trl.rewards import (
accuracy_reward,
get_soft_overlong_punishment,
Expand Down Expand Up @@ -113,11 +113,6 @@ def main(script_args, training_args, model_args, dataset_args):
attn_implementation=model_args.attn_implementation,
dtype=model_args.dtype,
)
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
training_args.model_init_kwargs["device_map"] = get_kbit_device_map()
training_args.model_init_kwargs["quantization_config"] = quantization_config

# Load the dataset
if dataset_args.datasets and script_args.dataset_name:
Expand All @@ -142,6 +137,7 @@ def main(script_args, training_args, model_args, dataset_args):
args=training_args,
train_dataset=dataset[script_args.dataset_train_split],
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
quantization_config=get_quantization_config(model_args),
peft_config=get_peft_config(model_args),
)

Expand Down
8 changes: 2 additions & 6 deletions trl/scripts/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def main(script_args, training_args, model_args, dataset_args):
from accelerate.logging import get_logger
from datasets import load_dataset

from trl import SFTTrainer, get_dataset, get_kbit_device_map, get_peft_config, get_quantization_config
from trl import SFTTrainer, get_dataset, get_peft_config, get_quantization_config

logger = get_logger(__name__)

Expand All @@ -77,11 +77,6 @@ def main(script_args, training_args, model_args, dataset_args):
attn_implementation=model_args.attn_implementation,
dtype=model_args.dtype,
)
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
training_args.model_init_kwargs["device_map"] = get_kbit_device_map()
training_args.model_init_kwargs["quantization_config"] = quantization_config

# Load the dataset
if dataset_args.datasets and script_args.dataset_name:
Expand All @@ -105,6 +100,7 @@ def main(script_args, training_args, model_args, dataset_args):
args=training_args,
train_dataset=dataset[script_args.dataset_train_split],
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
quantization_config=get_quantization_config(model_args),
peft_config=get_peft_config(model_args),
)

Expand Down
17 changes: 17 additions & 0 deletions trl/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from torch.utils.data import DataLoader
from transformers import (
AutoProcessor,
BitsAndBytesConfig,
DataCollator,
PreTrainedModel,
PreTrainedTokenizerBase,
Expand Down Expand Up @@ -478,6 +479,9 @@ class DPOTrainer(_BaseTrainer):
optimizers (`tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None]`, *optional*, defaults to `(None, None)`):
A tuple containing the optimizer and the scheduler to use. Will default to an instance of `AdamW` on your
model and a scheduler given by [`~transformers.get_linear_schedule_with_warmup`] controlled by `args`.
quantization_config ([`~transformers.BitsAndBytesConfig`], *optional*):
Quantization configuration used when loading the model from a model identifier. Combine with `peft_config`
for QLoRA training. Ignored (with a warning) if the model is already instantiated. Raises a `ValueError` if
`quantization_config` is also set in `args.model_init_kwargs`.
peft_config ([`~peft.PeftConfig`], *optional*):
PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
"""
Expand Down Expand Up @@ -511,6 +515,7 @@ def __init__(
compute_metrics: Callable[[EvalPrediction], dict] | None = None,
callbacks: list[TrainerCallback] | None = None,
optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None),
quantization_config: "BitsAndBytesConfig | None" = None,
peft_config: "PeftConfig | None" = None,
):
# Args
Expand All @@ -535,6 +540,13 @@ def __init__(
# Model
if isinstance(model, str):
model_init_kwargs = args.model_init_kwargs or {}
if quantization_config is not None:
if "quantization_config" in model_init_kwargs:
raise ValueError(
"You set `quantization_config` both as a trainer argument and in `args.model_init_kwargs`. "
"Please set it in only one place, preferably as a trainer argument."
)
model_init_kwargs["quantization_config"] = quantization_config
# Distributed training requires device_map=None ("auto" fails)
if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]:
model_init_kwargs["device_map"] = None
Expand All @@ -546,6 +558,11 @@ def __init__(
"You passed `model_init_kwargs` to the `DPOConfig`, but your model is already instantiated. "
"The `model_init_kwargs` will be ignored."
)
if quantization_config is not None:
logger.warning(
"You passed `quantization_config` to the trainer, but your model is already instantiated. The "
"`quantization_config` will be ignored."
)
Comment thread
cursor[bot] marked this conversation as resolved.
# Non-quantized models do not have the `is_loaded_in_{8,4}bit` attributes, whereas quantized models do
_is_quantized_model = getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False)
if ref_model is model:
Expand Down
17 changes: 17 additions & 0 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
AutoModelForSequenceClassification,
AutoProcessor,
AutoTokenizer,
BitsAndBytesConfig,
GenerationConfig,
PreTrainedModel,
PreTrainedTokenizerBase,
Expand Down Expand Up @@ -226,6 +227,9 @@ class GRPOTrainer(_BaseTrainer):
optimizers (`tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None]`, *optional*, defaults to `(None, None)`):
A tuple containing the optimizer and the scheduler to use. Will default to an instance of `AdamW` on your
model and a scheduler given by [`~transformers.get_linear_schedule_with_warmup`] controlled by `args`.
quantization_config ([`~transformers.BitsAndBytesConfig`], *optional*):
Quantization configuration used when loading the model from a model identifier. Combine with `peft_config`
for QLoRA training. Ignored (with a warning) if the model is already instantiated. Raises a `ValueError` if
`quantization_config` is also set in `args.model_init_kwargs`.
peft_config ([`~peft.PeftConfig`], *optional*):
PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
tools (list of `Callable`, *optional*):
Expand Down Expand Up @@ -280,6 +284,7 @@ def __init__(
reward_processing_classes: PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None = None,
callbacks: list[TrainerCallback] | None = None,
optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None),
quantization_config: "BitsAndBytesConfig | None" = None,
peft_config: "PeftConfig | None" = None,
tools: list[Callable] | None = None,
rollout_func: RolloutFunc | None = None,
Expand All @@ -294,6 +299,13 @@ def __init__(
# Model
if isinstance(model, str):
model_init_kwargs = args.model_init_kwargs or {}
if quantization_config is not None:
if "quantization_config" in model_init_kwargs:
raise ValueError(
"You set `quantization_config` both as a trainer argument and in `args.model_init_kwargs`. "
"Please set it in only one place, preferably as a trainer argument."
)
model_init_kwargs["quantization_config"] = quantization_config
Comment thread
cursor[bot] marked this conversation as resolved.
# Distributed training requires device_map=None ("auto" fails)
if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]:
model_init_kwargs["device_map"] = None
Expand All @@ -305,6 +317,11 @@ def __init__(
"You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
"The `model_init_kwargs` will be ignored."
)
if quantization_config is not None:
logger.warning(
"You passed `quantization_config` to the trainer, but your model is already instantiated. The "
"`quantization_config` will be ignored."
)
# Non-quantized models do not have the `is_loaded_in_{8,4}bit` attributes, whereas quantized models do
_is_quantized_model = getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False)

Expand Down
18 changes: 18 additions & 0 deletions trl/trainer/reward_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
BitsAndBytesConfig,
DataCollator,
PreTrainedModel,
PreTrainedTokenizerBase,
Expand Down Expand Up @@ -309,6 +310,10 @@ class RewardTrainer(_BaseTrainer):
by this function will be reflected in the predictions received by `compute_metrics`.

Note that the labels (second parameter) will be `None` if the dataset does not have them.
quantization_config ([`~transformers.BitsAndBytesConfig`], *optional*):
Quantization configuration used when loading the model from a model identifier. Combine with `peft_config`
for QLoRA training. Ignored (with a warning) if the model is already instantiated. Raises a `ValueError` if
`quantization_config` is also set in `args.model_init_kwargs`.
peft_config ([`~peft.PeftConfig`], *optional*):
PEFT configuration used to wrap the model. If `None`, the model is not wrapped. Note that if the loaded
model is a causal LM, it's highly recommended to set `modules_to_save=["score"]` in the PEFT configuration
Expand All @@ -332,6 +337,7 @@ def __init__(
optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None),
optimizer_cls_and_kwargs: tuple[type[torch.optim.Optimizer], dict[str, Any]] | None = None,
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
quantization_config: "BitsAndBytesConfig | None" = None,
peft_config: "PeftConfig | None" = None,
):
# Args
Expand Down Expand Up @@ -359,6 +365,13 @@ def __init__(
set_seed(args.seed)
if isinstance(model, str):
model_init_kwargs = args.model_init_kwargs or {}
if quantization_config is not None:
if "quantization_config" in model_init_kwargs:
raise ValueError(
"You set `quantization_config` both as a trainer argument and in `args.model_init_kwargs`. "
"Please set it in only one place, preferably as a trainer argument."
)
model_init_kwargs["quantization_config"] = quantization_config
# Distributed training requires device_map=None ("auto" fails)
if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]:
model_init_kwargs["device_map"] = None
Expand All @@ -372,6 +385,11 @@ def __init__(
"You passed `model_init_kwargs` to the `RewardConfig`, but your model is already instantiated. "
"The `model_init_kwargs` will be ignored."
)
if quantization_config is not None:
logger.warning(
"You passed `quantization_config` to the trainer, but your model is already instantiated. The "
"`quantization_config` will be ignored."
)
# Validate that the model has num_labels = 1 (required for reward models)
if getattr(model.config, "num_labels", None) != 1:
raise ValueError(
Expand Down
Loading
Loading