16 changes: 6 additions & 10 deletions docs/source/peft_integration.md
@@ -449,11 +449,13 @@ python trl/scripts/sft.py \

 #### Python Example

+Pass the `quantization_config` directly to the trainer alongside `peft_config`; the trainer loads and quantizes the model for you. The same `quantization_config` argument is available on [`SFTTrainer`], [`DPOTrainer`], [`GRPOTrainer`], and [`RLOOTrainer`].
+
 ```python
 import torch

 from peft import LoraConfig
-from transformers import AutoModelForCausalLM, BitsAndBytesConfig
+from transformers import BitsAndBytesConfig
 from trl import SFTConfig, SFTTrainer

 # Configure 4-bit quantization
@@ -464,13 +466,6 @@ bnb_config = BitsAndBytesConfig(
     bnb_4bit_use_double_quant=True,
 )

-# Load model with quantization
-model = AutoModelForCausalLM.from_pretrained(
-    "meta-llama/Llama-2-7b-hf",
-    quantization_config=bnb_config,
-    device_map="auto",
-)
-
 # Configure LoRA
 peft_config = LoraConfig(
     r=32,
@@ -486,11 +481,12 @@ training_args = SFTConfig(
     ...
 )

-# Create trainer with PEFT config
+# Create trainer with quantization and PEFT config
 trainer = SFTTrainer(
-    model=model,
+    model="meta-llama/Llama-2-7b-hf",
     args=training_args,
     train_dataset=dataset,
+    quantization_config=bnb_config,
     peft_config=peft_config,
 )
