Support PEFT with Liger in DPO#6159
Conversation
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
There was a problem hiding this comment.
Pull request overview
Enables using use_liger_kernel=True with PEFT (e.g., LoRA) in DPOTrainer by replacing the previous blanket rejection with targeted compatibility checks and by extending the Liger reference-computation path to work when ref_model is not instantiated for PEFT.
Changes:
- Replace the PEFT+Liger blanket
NotImplementedErrorwith a targeted validation that rejects PEFT adapters applied tolm_head(to avoid silently ignoring head adapters in the fused loss path). - Update
_compute_loss_ligerto unwrap PEFT models before backbone execution and to compute reference hidden states via adapter disabling/switching whenref_model is None. - Remove the test that asserted PEFT+Liger init always fails.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
trl/trainer/dpo_trainer.py |
Adds PEFT+Liger compatibility checks and extends Liger loss computation to support PEFT reference behavior. |
tests/test_dpo_trainer.py |
Removes the outdated PEFT+Liger init-failure test (needs replacement coverage for the new behavior). |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| peft_config=LoraConfig(), | ||
| ) | ||
|
|
||
| def test_train_with_iterable_dataset(self): |
There was a problem hiding this comment.
I'm adding tests.
| output_embeddings = model.get_output_embeddings() | ||
| if isinstance(output_embeddings, BaseTunerLayer): | ||
| raise ValueError( | ||
| "`use_liger_kernel=True` is incompatible with applying a PEFT adapter to `lm_head`. The Liger " | ||
| "fused DPO loss reads `lm_head.weight` directly, so the adapter on the head is ignored and " | ||
| "never trained. Either remove `'lm_head'` from your `target_modules`, or set " | ||
| "`use_liger_kernel=False`." | ||
| ) |
There was a problem hiding this comment.
OK, I'm implementing the guard differently. Additionally, I'm checking if other trainers need this guard as well.
There was a problem hiding this comment.
Support PEFT models with use_liger_kernel=True in
DPOTrainer.This PR enables PEFT models (e.g. LoRA) to be used with
use_liger_kernel=TrueinDPOTrainer, lifting the blanket NotImplementedError that previously blocked all PEFT+Liger combinations.Motivation
The Liger fused DPO loss bypasses the model's
forward()and multiplies hidden states bylm_head.weightdirectly. The previous guard raisedNotImplementedErrorfor any PEFT model, but this was too broad: the only genuinely incompatible case is when lm_head itself is wrapped by a PEFT adapter (e.g. "lm_head" in target_modules), because then lm_head.weight is the frozen base weight and the adapter delta is silently ignored. When lm_head is not adapted, PEFT+Liger works correctly.Changes
Compatibility and Error Handling
use_liger_kernel=Truewhen a PEFT adapter is applied tolm_head, raising a clear error if this unsupported configuration is detected. This avoids silent failures where the head adapter would not be trained.BaseTunerLayerfrompeft.tuners.tuners_utilsto enable the above compatibility check.Model Unwrapping and Reference Handling
Note
Medium Risk
Changes core DPO training loss paths for PEFT+Liger; incorrect reference or backbone unwrapping could skew gradients, though tests and explicit guards reduce silent failure risk.
Overview
DPOTrainernow allowsuse_liger_kernel=Truewith PEFT when the setup is actually safe, instead of rejecting every PEFT model.The old blanket
NotImplementedErroris gone. Init-timeValueErrorchecks blocklm_headintarget_modules(Liger reads frozenlm_head.weightand would skip head LoRA) and prompt-learning PEFT (Liger bypassesPeftModel.forward(), so virtual tokens never apply)._compute_loss_ligerunwraps PEFT viamodel.base_model.model, and when there is no separateref_modelit builds reference hidden states using the sameuse_adapter/ ref-adapter path as the standard DPO loss.Tests cover the new error cases and an end-to-end LoRA (no
lm_head) + Liger training run.Reviewed by Cursor Bugbot for commit 5cb1748. Bugbot is set up for automated code reviews on this repo. Configure here.