Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 0 additions & 19 deletions tests/test_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,25 +788,6 @@ def test_train_with_liger(self):
new_param = trainer.model.get_parameter(n)
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

@require_liger_kernel
@require_peft
def test_init_fails_with_peft_and_liger(self):
    """Check that building a `DPOTrainer` with a PEFT config while the Liger kernel is enabled raises.

    The trainer is expected to refuse construction with a `NotImplementedError` whose message matches
    the pattern passed to `pytest.raises` below.
    """
    preference_split = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")

    # Enable the Liger kernel through the config and turn off experiment reporting for the test run.
    liger_args = DPOConfig(output_dir=self.tmp_dir, use_liger_kernel=True, report_to="none")

    trainer_kwargs = {
        "model": "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        "args": liger_args,
        "train_dataset": preference_split,
        "peft_config": LoraConfig(),
    }

    # Only the constructor call sits inside the `raises` context, so the expected error must come from it.
    with pytest.raises(NotImplementedError, match="Liger DPO loss is not implemented for PEFT models."):
        DPOTrainer(**trainer_kwargs)

def test_train_with_iterable_dataset(self):

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm adding tests.

dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train", streaming=True)

Expand Down
46 changes: 41 additions & 5 deletions trl/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
if is_peft_available():
import peft
from peft import LoraConfig, PeftConfig, PeftModel, get_peft_model
from peft.tuners.tuners_utils import BaseTunerLayer


if is_liger_kernel_available():
Expand Down Expand Up @@ -753,7 +754,19 @@ def __init__(
"`precompute_ref_log_probs` or set `use_liger_kernel` to False."
)
if is_peft_model(model):
raise NotImplementedError("Liger DPO loss is not implemented for PEFT models.")
# The Liger fused DPO loss multiplies the hidden states by `lm_head.weight` directly. When the LM head
# is targeted by a PEFT adapter (`"lm_head"` in `target_modules`), `lm_head.weight` is the frozen base
# weight and the trainable adapter parameters live in separate submodules that Liger never sees. The
# head adapter would silently receive no gradient, so the model trains as if `lm_head` were frozen.
# Fail loudly rather than train a silently-frozen head.
output_embeddings = model.get_output_embeddings()
if isinstance(output_embeddings, BaseTunerLayer):
raise ValueError(
"`use_liger_kernel=True` is incompatible with applying a PEFT adapter to `lm_head`. The Liger "
"fused DPO loss reads `lm_head.weight` directly, so the adapter on the head is ignored and "
"never trained. Either remove `'lm_head'` from your `target_modules`, or set "
"`use_liger_kernel=False`."
)
Comment on lines +762 to +769

@albertvillanova albertvillanova Jun 26, 2026

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I'm implementing the guard differently. Additionally, I'm checking if other trainers need this guard as well.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


# Dataset
# Skip dataset preparation if it's a VLM, where preprocessing (e.g., image-to-pixel conversion) is too costly
Expand Down Expand Up @@ -1148,15 +1161,18 @@ def _compute_loss_liger(self, model, inputs, return_outputs):
model_kwargs = {k: v for k, v in inputs.items() if k not in _non_model_keys}
model_kwargs["use_cache"] = False

if is_peft_model(model):
model = model.base_model.model

# `base_model` gives the backbone model (skipping `lm_head`) — text decoder for LMs, multimodal wrapper for
# VLMs (so vision-token injection runs before the text decoder). `get_decoder()` won't do: on VLMs it
# returns just the text stack and feeds image-placeholder IDs through it.
# Pre-5.0 transformers VLMs set `base_model_prefix = ""` so `base_model is self` (re-runs `lm_head`).
# Fall back to `.model` there.
if self._is_vlm and Version(transformers.__version__) < Version("5.0.0"):
backbone, ref_backbone = model.model, self.ref_model.model
backbone = model.model
else:
backbone, ref_backbone = model.base_model, self.ref_model.base_model
backbone = model.base_model

outputs = backbone(**model_kwargs)
hidden_states = outputs.last_hidden_state[:, :-1].contiguous()
Expand All @@ -1165,8 +1181,28 @@ def _compute_loss_liger(self, model, inputs, return_outputs):
bias = lm_head.bias

with torch.no_grad(), disable_gradient_checkpointing(self.model, self.args.gradient_checkpointing_kwargs):
ref_outputs = ref_backbone(**model_kwargs)
ref_lm_head = self.ref_model.get_output_embeddings()
if self.ref_model is None:
# PEFT model with no explicit reference model: recover reference behaviour by disabling / switching to
# the frozen "ref" adapter, exactly as _compute_loss does for logit-based reference computation.
model_unwrapped = self.accelerator.unwrap_model(self.model)
with use_adapter(
model_unwrapped, adapter_name="ref" if "ref" in model_unwrapped.peft_config else None
):
ref_model_inner = model_unwrapped.base_model.model
if self._is_vlm and Version(transformers.__version__) < Version("5.0.0"):
ref_backbone = ref_model_inner.model
else:
ref_backbone = ref_model_inner.base_model
ref_outputs = ref_backbone(**model_kwargs)
ref_lm_head = model_unwrapped.get_output_embeddings()
else:
ref_model_inner = self.ref_model.base_model.model if is_peft_model(self.ref_model) else self.ref_model
if self._is_vlm and Version(transformers.__version__) < Version("5.0.0"):
ref_backbone = ref_model_inner.model
else:
ref_backbone = ref_model_inner.base_model
ref_outputs = ref_backbone(**model_kwargs)
ref_lm_head = self.ref_model.get_output_embeddings()
ref_hidden_states = ref_outputs.last_hidden_state[:, :-1].contiguous()
ref_weight = ref_lm_head.weight
ref_bias = ref_lm_head.bias
Expand Down
Loading