diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index ca5fa2dcd82..ef3deaad10e 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -36,7 +36,8 @@ if is_peft_available(): - from peft import LoraConfig, get_peft_model + from peft import LoraConfig, PromptTuningConfig, get_peft_model + from peft.utils import TaskType class TestDataCollatorForPreference(TrlTestCase): @@ -789,21 +790,63 @@ def test_train_with_liger(self): @require_liger_kernel @require_peft - def test_init_fails_with_peft_and_liger(self): + def test_liger_kernel_with_peft_lm_head_raises(self): + # The Liger fused DPO loss reads `lm_head.weight` directly, so a LoRA adapter on `lm_head` is silently + # ignored and never trained. The trainer must fail fast instead of training a silently-frozen head. dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train") + training_args = DPOConfig(output_dir=self.tmp_dir, use_liger_kernel=True, report_to="none") + with pytest.raises(ValueError, match="lm_head"): + DPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=dataset, + peft_config=LoraConfig(target_modules=["q_proj", "v_proj", "lm_head"]), + ) + @require_liger_kernel + @require_peft + def test_liger_kernel_with_peft_trains(self): + # A LoRA adapter that does not target lm_head leaves the head as a plain Linear, so Liger reads the real + # weight. Verify the full PEFT+Liger path actually trains (peft params change, base params stay frozen). + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + model = AutoModelForCausalLM.from_pretrained(model_id, dtype="float32") + base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()] + dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train") training_args = DPOConfig( output_dir=self.tmp_dir, + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates use_liger_kernel=True, report_to="none", ) + trainer = DPOTrainer( + model=model_id, + args=training_args, + train_dataset=dataset, + peft_config=LoraConfig(target_modules=["q_proj", "v_proj"]), + ) + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + trainer.train() + assert trainer.state.log_history[-1]["train_loss"] is not None + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if n in base_param_names: + torch.testing.assert_close(param, new_param, msg=f"Parameter {n} has changed.") + elif "base_layer" not in n: + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." - with pytest.raises(NotImplementedError, match="Liger DPO loss is not implemented for PEFT models."): + @require_liger_kernel + @require_peft + def test_liger_kernel_with_peft_prompt_learning_raises(self): + # Prompt-learning methods inject virtual tokens via PeftModel.forward(), which the Liger DPO loss bypasses. + # The trainer must fail fast to avoid computing the loss on the wrong (truncated) sequence. + dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train") + training_args = DPOConfig(output_dir=self.tmp_dir, use_liger_kernel=True, report_to="none") + with pytest.raises(ValueError, match="prompt-learning"): DPOTrainer( model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset, - peft_config=LoraConfig(), + peft_config=PromptTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=8), ) def test_train_with_iterable_dataset(self): diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index cd91d913e01..a6d06f148b3 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -65,7 +65,8 @@ if is_peft_available(): import peft - from peft import LoraConfig, PeftConfig, PeftModel, get_peft_model + from peft import LoraConfig, PeftConfig, PeftModel, PromptLearningConfig, get_peft_model + from peft.tuners.tuners_utils import BaseTunerLayer if is_liger_kernel_available(): @@ -753,7 +754,31 @@ def __init__( "`precompute_ref_log_probs` or set `use_liger_kernel` to False." ) if is_peft_model(model): - raise NotImplementedError("Liger DPO loss is not implemented for PEFT models.") + # The Liger fused DPO loss multiplies the hidden states by `lm_head.weight` directly. When the LM head + # is targeted by a PEFT adapter (`"lm_head"` in `target_modules`), `lm_head.weight` is the frozen base + # weight and the trainable adapter parameters live in separate submodules that Liger never sees. The + # head adapter would silently receive no gradient, so the model trains as if `lm_head` were frozen. + # Fail loudly rather than train a silently-frozen head. + output_embeddings = model.get_output_embeddings() + if isinstance(output_embeddings, BaseTunerLayer): + raise ValueError( + "`use_liger_kernel=True` is incompatible with applying a PEFT adapter to `lm_head`. The Liger " + "fused DPO loss reads `lm_head.weight` directly, so the adapter on the head is ignored and " + "never trained. Either remove `'lm_head'` from your `target_modules`, or set " + "`use_liger_kernel=False`." + ) + # Prompt-learning methods (PromptTuning, PrefixTuning, P-Tuning) inject virtual tokens via + # `PeftModel.forward()`. The Liger DPO loss bypasses `PeftModel.forward()` by calling the backbone + # directly, so virtual tokens are never prepended and the loss is computed on the wrong sequence. + # Fail loudly rather than train on a silently corrupted input. + if any(isinstance(cfg, PromptLearningConfig) for cfg in model.peft_config.values()): + raise ValueError( + "`use_liger_kernel=True` is incompatible with prompt-learning PEFT methods (PromptTuning, " + "PrefixTuning, P-Tuning). The Liger DPO loss bypasses `PeftModel.forward()` by calling the " + "backbone directly, so virtual tokens are never prepended and the loss is computed on the " + "wrong sequence. Use a weight-based adapter such as LoRA instead, or set " + "`use_liger_kernel=False`." + ) # Dataset # Skip dataset preparation if it's a VLM, where preprocessing (e.g., image-to-pixel conversion) is too costly @@ -1148,15 +1173,18 @@ def _compute_loss_liger(self, model, inputs, return_outputs): model_kwargs = {k: v for k, v in inputs.items() if k not in _non_model_keys} model_kwargs["use_cache"] = False + if is_peft_model(model): + model = model.base_model.model + # `base_model` gives the backbone model (skipping `lm_head`) — text decoder for LMs, multimodal wrapper for # VLMs (so vision-token injection runs before the text decoder). `get_decoder()` won't do: on VLMs it # returns just the text stack and feeds image-placeholder IDs through it. # Pre-5.0 transformers VLMs set `base_model_prefix = ""` so `base_model is self` (re-runs `lm_head`). # Fall back to `.model` there. if self._is_vlm and Version(transformers.__version__) < Version("5.0.0"): - backbone, ref_backbone = model.model, self.ref_model.model + backbone = model.model else: - backbone, ref_backbone = model.base_model, self.ref_model.base_model + backbone = model.base_model outputs = backbone(**model_kwargs) hidden_states = outputs.last_hidden_state[:, :-1].contiguous() @@ -1165,8 +1193,28 @@ def _compute_loss_liger(self, model, inputs, return_outputs): bias = lm_head.bias with torch.no_grad(), disable_gradient_checkpointing(self.model, self.args.gradient_checkpointing_kwargs): - ref_outputs = ref_backbone(**model_kwargs) - ref_lm_head = self.ref_model.get_output_embeddings() + if self.ref_model is None: + # PEFT model with no explicit reference model: recover reference behaviour by disabling / switching to + # the frozen "ref" adapter, exactly as _compute_loss does for logit-based reference computation. + model_unwrapped = self.accelerator.unwrap_model(self.model) + with use_adapter( + model_unwrapped, adapter_name="ref" if "ref" in model_unwrapped.peft_config else None + ): + ref_model_inner = model_unwrapped.base_model.model + if self._is_vlm and Version(transformers.__version__) < Version("5.0.0"): + ref_backbone = ref_model_inner.model + else: + ref_backbone = ref_model_inner.base_model + ref_outputs = ref_backbone(**model_kwargs) + ref_lm_head = model_unwrapped.get_output_embeddings() + else: + ref_model_inner = self.ref_model.base_model.model if is_peft_model(self.ref_model) else self.ref_model + if self._is_vlm and Version(transformers.__version__) < Version("5.0.0"): + ref_backbone = ref_model_inner.model + else: + ref_backbone = ref_model_inner.base_model + ref_outputs = ref_backbone(**model_kwargs) + ref_lm_head = self.ref_model.get_output_embeddings() ref_hidden_states = ref_outputs.last_hidden_state[:, :-1].contiguous() ref_weight = ref_lm_head.weight ref_bias = ref_lm_head.bias