
Add quantization_config trainer argument (streamline QLoRA) #6157

Open
qgallouedec wants to merge 8 commits into main from native-quantization-config

@qgallouedec (Member) commented Jun 24, 2026

Adds a quantization_config argument to SFTTrainer, DPOTrainer, GRPOTrainer, RLOOTrainer, and RewardTrainer, so QLoRA no longer requires reaching into model_init_kwargs (or worse, loading the model manually).

After:

SFTTrainer(
    model="meta-llama/Llama-2-7b-hf",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    peft_config=LoraConfig(),
    train_dataset=dataset,
)

Compare with before (many resources are written like this!):

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    device_map="auto",
)
SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=LoraConfig(),
)

Before (the "right" way, but not very popular):

SFTTrainer(
    model="meta-llama/Llama-2-7b-hf",
    args=SFTConfig(model_init_kwargs={"quantization_config": BitsAndBytesConfig(load_in_4bit=True)}),
    peft_config=LoraConfig(),
    train_dataset=dataset,
)

It sits next to peft_config (the other non-serializable QLoRA ingredient), flows into from_pretrained, and raises if also set in args.model_init_kwargs.

Changes

  • New quantization_config arg on the five trainers above (+ docstrings).
  • The trl/scripts/{sft,dpo,grpo,rloo,reward}.py CLIs now pass it directly instead of injecting into model_init_kwargs.
  • This drops the redundant model_init_kwargs["device_map"] = get_kbit_device_map() line. Verified on 8×H100 that QLoRA trains identically with and without it on transformers 4.56.2 (min supported) and 5.13: distributed runs override device_map to None anyway, and single-process runs auto-place quantized weights on the current CUDA device. See "Remove redundant get_kbit_device_map()" (#6158).
  • Updated the QLoRA example in docs/source/peft_integration.md.

Note

Medium Risk
Touches model loading for all major TRL trainers and reference-model paths; behavior change for QLoRA users but scoped to optional loading kwargs with explicit conflict checks.

Overview
Adds a quantization_config trainer argument (alongside peft_config) on SFTTrainer, DPOTrainer, GRPOTrainer, RLOOTrainer, and RewardTrainer, so QLoRA can pass a model id string and let the trainer load/quantize via from_pretrained instead of pre-loading with AutoModelForCausalLM.

When the model is loaded from a string, the trainer merges quantization_config into model_init_kwargs (and the same for reference models where applicable), errors if it is also set in args.model_init_kwargs, and warns if a pre-instantiated model is passed. CLI entrypoints and example scripts now pass get_quantization_config(model_args) directly to the trainer and no longer inject quantization_config / get_kbit_device_map() into model_init_kwargs.
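
The merge, error, and warning behavior described above can be sketched in plain Python. This is a minimal illustration, not the code from this PR: the helper name `resolve_model_init_kwargs`, its signature, and the message texts are assumptions made up for the example, and the quantization config is treated as an opaque object.

```python
from __future__ import annotations

import warnings
from typing import Any


def resolve_model_init_kwargs(
    model: Any,
    quantization_config: Any | None,
    model_init_kwargs: dict[str, Any] | None,
) -> dict[str, Any]:
    """Hypothetical helper: fold a trainer-level quantization_config into the loading kwargs."""
    kwargs = dict(model_init_kwargs or {})  # copy so the caller's dict is not mutated
    if quantization_config is None:
        return kwargs
    if not isinstance(model, str):
        # A pre-instantiated model is never reloaded, so the config cannot take effect.
        warnings.warn("quantization_config is ignored because the model is already instantiated.")
        return kwargs
    if "quantization_config" in kwargs:
        # Two sources of truth: refuse to guess which one the user meant.
        raise ValueError(
            "quantization_config was passed both to the trainer and in model_init_kwargs; set it in one place."
        )
    kwargs["quantization_config"] = quantization_config
    return kwargs
```

Raising on the duplicate rather than silently preferring one value keeps the conflict visible, which matches the "explicit conflict checks" wording in the risk note.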

Docs and Colab notebooks are updated to the new pattern (model id + quantization_config on the trainer, model_init_kwargs for attn/dtype where needed).

Reviewed by Cursor Bugbot for commit 7be97c4.

Comment thread trl/trainer/dpo_trainer.py
@bot-ci-comment

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@chatgpt-codex-connector (Bot) left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 45d6a2decd

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None),
optimizer_cls_and_kwargs: tuple[type[torch.optim.Optimizer], dict[str, Any]] | None = None,
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
quantization_config: "BitsAndBytesConfig | None" = None,

P2: Preserve positional peft_config compatibility

Adding quantization_config before the existing peft_config parameter shifts any current positional peft_config argument into quantization_config because this public constructor is not keyword-only. In existing calls that pass peft_config positionally, a model id will forward a PeftConfig object to from_pretrained(..., quantization_config=...) and fail, while an already-instantiated model will ignore it and train without the adapter; the same signature insertion appears in the other updated trainers. Put the new argument after peft_config or otherwise preserve the old positional layout.
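
The positional shift described above can be reproduced with plain functions. This is a simplified sketch: the three functions are hypothetical stand-ins with shortened signatures, not the real trainer constructors, and `lora` is just a placeholder object standing in for a PeftConfig.

```python
def trainer_old(model, preprocess=None, peft_config=None):
    """Stand-in for the signature before this change."""
    return {"quantization_config": None, "peft_config": peft_config}


def trainer_inserted_before(model, preprocess=None, quantization_config=None, peft_config=None):
    """New argument inserted ahead of peft_config: the third positional slot changes meaning."""
    return {"quantization_config": quantization_config, "peft_config": peft_config}


def trainer_appended(model, preprocess=None, peft_config=None, quantization_config=None):
    """New argument placed after peft_config: the old positional layout is preserved."""
    return {"quantization_config": quantization_config, "peft_config": peft_config}


lora = object()  # placeholder for a PeftConfig instance

# The same positional call, bound against each signature.
old = trainer_old("model-id", None, lora)
shifted = trainer_inserted_before("model-id", None, lora)
kept = trainer_appended("model-id", None, lora)

print(old["peft_config"] is lora)              # adapter config lands where expected
print(shifted["quantization_config"] is lora)  # adapter config silently becomes the quantization config
print(kept["peft_config"] is lora)             # appended layout keeps the old binding
```

Callers that pass peft_config by keyword are unaffected in all three variants; only positional callers see the difference.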

Member Author

Although not specifically disallowed, it would be very surprising for peft_config to be passed as a positional arg.

Comment thread trl/trainer/grpo_trainer.py

@cursor (Bot) left a comment

Cursor Bugbot has reviewed your changes using default effort and found 1 potential issue.

There are 3 total unresolved issues (including 2 from previous reviews).

Reviewed by Cursor Bugbot for commit 0bb426c.

Comment thread trl/trainer/sft_trainer.py

@sergiopaniego (Member) left a comment

the example scripts/notebooks in the examples/ folder should also be reviewed and updated

@qgallouedec (Member, Author)

Right @sergiopaniego, updated!
