Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 47 additions & 4 deletions tests/test_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@


if is_peft_available():
from peft import LoraConfig, get_peft_model
from peft import LoraConfig, PromptTuningConfig, get_peft_model
from peft.utils import TaskType


class TestDataCollatorForPreference(TrlTestCase):
Expand Down Expand Up @@ -789,21 +790,63 @@ def test_train_with_liger(self):

@require_liger_kernel
@require_peft
def test_init_fails_with_peft_and_liger(self):
def test_liger_kernel_with_peft_lm_head_raises(self):
    # Guard test: with `use_liger_kernel=True`, a LoRA config that lists `lm_head` among its target modules
    # must make trainer construction raise. The Liger fused DPO loss reads `lm_head.weight` directly, so an
    # adapter on the head would be silently ignored and never trained.
    preference_data = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
    config = DPOConfig(output_dir=self.tmp_dir, use_liger_kernel=True, report_to="none")
    with pytest.raises(ValueError, match="lm_head"):
        # The adapter config is built inside the context so that any error it raises is also captured.
        head_lora = LoraConfig(target_modules=["q_proj", "v_proj", "lm_head"])
        DPOTrainer(
            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
            args=config,
            train_dataset=preference_data,
            peft_config=head_lora,
        )

@require_liger_kernel
@require_peft
def test_liger_kernel_with_peft_trains(self):
    # End-to-end check of the PEFT + Liger path with a LoRA adapter that does not target `lm_head` (the head
    # stays a plain Linear, so Liger reads the real weight). After training, the base parameters must be
    # unchanged and the adapter parameters must have moved.
    checkpoint = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"

    # Collect the plain model's parameter names, prefixed with `base_model.model.` so they can be matched
    # against the parameter names of the trainer's model below.
    plain_model = AutoModelForCausalLM.from_pretrained(checkpoint, dtype="float32")
    frozen_names = {f"base_model.model.{n}" for n, _ in plain_model.named_parameters()}

    train_split = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
    config = DPOConfig(
        output_dir=self.tmp_dir,
        learning_rate=0.1,  # use higher lr because gradients are tiny and default lr can stall updates
        use_liger_kernel=True,
        report_to="none",
    )
    lora = LoraConfig(target_modules=["q_proj", "v_proj"])
    trainer = DPOTrainer(model=checkpoint, args=config, train_dataset=train_split, peft_config=lora)

    # Snapshot every parameter (frozen and trainable alike) before the optimizer touches anything.
    snapshot = {n: param.clone() for n, param in trainer.model.named_parameters()}

    trainer.train()

    assert trainer.state.log_history[-1]["train_loss"] is not None
    for n, before in snapshot.items():
        after = trainer.model.get_parameter(n)
        if n in frozen_names:
            # Base weights must be left untouched by training.
            torch.testing.assert_close(before, after, msg=f"Parameter {n} has changed.")
            continue
        if "base_layer" in n:
            # Parameters whose name contains `base_layer` are deliberately not checked either way.
            continue
        assert not torch.equal(before, after), f"Parameter {n} has not changed."

with pytest.raises(NotImplementedError, match="Liger DPO loss is not implemented for PEFT models."):
@require_liger_kernel
@require_peft
def test_liger_kernel_with_peft_prompt_learning_raises(self):
# Guard test: constructing a DPOTrainer with `use_liger_kernel=True` and a prompt-learning PEFT config
# (here `PromptTuningConfig`) is expected to raise a ValueError whose message matches "prompt-learning".
# Prompt-learning methods inject virtual tokens via PeftModel.forward(), which the Liger DPO loss bypasses.
# The trainer must fail fast to avoid computing the loss on the wrong (truncated) sequence.
dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
training_args = DPOConfig(output_dir=self.tmp_dir, use_liger_kernel=True, report_to="none")
with pytest.raises(ValueError, match="prompt-learning"):
DPOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
args=training_args,
train_dataset=dataset,
# NOTE(review): `peft_config` appears twice below. The `LoraConfig()` line looks like the removed side of
# the diff hunk and the `PromptTuningConfig(...)` line the added side — confirm that only the latter is
# present in the actual file, since a repeated keyword argument is a syntax error.
peft_config=LoraConfig(),
peft_config=PromptTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=8),
)

def test_train_with_iterable_dataset(self):
Expand Down
60 changes: 54 additions & 6 deletions trl/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@

if is_peft_available():
import peft
from peft import LoraConfig, PeftConfig, PeftModel, get_peft_model
from peft import LoraConfig, PeftConfig, PeftModel, PromptLearningConfig, get_peft_model
from peft.tuners.tuners_utils import BaseTunerLayer


if is_liger_kernel_available():
Expand Down Expand Up @@ -753,7 +754,31 @@ def __init__(
"`precompute_ref_log_probs` or set `use_liger_kernel` to False."
)
if is_peft_model(model):
raise NotImplementedError("Liger DPO loss is not implemented for PEFT models.")
# The Liger fused DPO loss multiplies the hidden states by `lm_head.weight` directly. When the LM head
# is targeted by a PEFT adapter (`"lm_head"` in `target_modules`), `lm_head.weight` is the frozen base
# weight and the trainable adapter parameters live in separate submodules that Liger never sees. The
# head adapter would silently receive no gradient, so the model trains as if `lm_head` were frozen.
# Fail loudly rather than train a silently-frozen head.
output_embeddings = model.get_output_embeddings()
if isinstance(output_embeddings, BaseTunerLayer):
raise ValueError(
"`use_liger_kernel=True` is incompatible with applying a PEFT adapter to `lm_head`. The Liger "
"fused DPO loss reads `lm_head.weight` directly, so the adapter on the head is ignored and "
"never trained. Either remove `'lm_head'` from your `target_modules`, or set "
"`use_liger_kernel=False`."
)
Comment on lines +762 to +769

@albertvillanova albertvillanova Jun 26, 2026

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I'm implementing the guard differently. Additionally, I'm checking if other trainers need this guard as well.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

# Prompt-learning methods (PromptTuning, PrefixTuning, P-Tuning) inject virtual tokens via
# `PeftModel.forward()`. The Liger DPO loss bypasses `PeftModel.forward()` by calling the backbone
# directly, so virtual tokens are never prepended and the loss is computed on the wrong sequence.
# Fail loudly rather than train on a silently corrupted input.
if any(isinstance(cfg, PromptLearningConfig) for cfg in model.peft_config.values()):
raise ValueError(
"`use_liger_kernel=True` is incompatible with prompt-learning PEFT methods (PromptTuning, "
"PrefixTuning, P-Tuning). The Liger DPO loss bypasses `PeftModel.forward()` by calling the "
"backbone directly, so virtual tokens are never prepended and the loss is computed on the "
"wrong sequence. Use a weight-based adapter such as LoRA instead, or set "
"`use_liger_kernel=False`."
)

# Dataset
# Skip dataset preparation if it's a VLM, where preprocessing (e.g., image-to-pixel conversion) is too costly
Expand Down Expand Up @@ -1148,15 +1173,18 @@ def _compute_loss_liger(self, model, inputs, return_outputs):
model_kwargs = {k: v for k, v in inputs.items() if k not in _non_model_keys}
model_kwargs["use_cache"] = False

if is_peft_model(model):
model = model.base_model.model

# `base_model` gives the backbone model (skipping `lm_head`) — text decoder for LMs, multimodal wrapper for
# VLMs (so vision-token injection runs before the text decoder). `get_decoder()` won't do: on VLMs it
# returns just the text stack and feeds image-placeholder IDs through it.
# Pre-5.0 transformers VLMs set `base_model_prefix = ""` so `base_model is self` (re-runs `lm_head`).
# Fall back to `.model` there.
if self._is_vlm and Version(transformers.__version__) < Version("5.0.0"):
backbone, ref_backbone = model.model, self.ref_model.model
backbone = model.model
else:
backbone, ref_backbone = model.base_model, self.ref_model.base_model
backbone = model.base_model

outputs = backbone(**model_kwargs)
hidden_states = outputs.last_hidden_state[:, :-1].contiguous()
Expand All @@ -1165,8 +1193,28 @@ def _compute_loss_liger(self, model, inputs, return_outputs):
bias = lm_head.bias

with torch.no_grad(), disable_gradient_checkpointing(self.model, self.args.gradient_checkpointing_kwargs):
ref_outputs = ref_backbone(**model_kwargs)
ref_lm_head = self.ref_model.get_output_embeddings()
if self.ref_model is None:
# PEFT model with no explicit reference model: recover reference behaviour by disabling / switching to
# the frozen "ref" adapter, exactly as _compute_loss does for logit-based reference computation.
model_unwrapped = self.accelerator.unwrap_model(self.model)
with use_adapter(
model_unwrapped, adapter_name="ref" if "ref" in model_unwrapped.peft_config else None
):
ref_model_inner = model_unwrapped.base_model.model
if self._is_vlm and Version(transformers.__version__) < Version("5.0.0"):
ref_backbone = ref_model_inner.model
else:
ref_backbone = ref_model_inner.base_model
ref_outputs = ref_backbone(**model_kwargs)
ref_lm_head = model_unwrapped.get_output_embeddings()
else:
ref_model_inner = self.ref_model.base_model.model if is_peft_model(self.ref_model) else self.ref_model
if self._is_vlm and Version(transformers.__version__) < Version("5.0.0"):
ref_backbone = ref_model_inner.model
else:
ref_backbone = ref_model_inner.base_model
ref_outputs = ref_backbone(**model_kwargs)
ref_lm_head = self.ref_model.get_output_embeddings()
ref_hidden_states = ref_outputs.last_hidden_state[:, :-1].contiguous()
ref_weight = ref_lm_head.weight
ref_bias = ref_lm_head.bias
Expand Down
Loading