Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion tests/test_grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@


if is_peft_available():
from peft import LoraConfig, PeftModel, get_peft_model
from peft import LoraConfig, PeftModel, PromptTuningConfig, get_peft_model
from peft.utils import TaskType


def multiply_tool(a: int, b: int) -> int:
Expand Down Expand Up @@ -627,6 +628,22 @@ def test_liger_kernel_with_peft_modules_to_save_lm_head_allowed(self):
with pytest.raises(ImportError):
GRPOTrainer(**kwargs)

@require_peft
def test_liger_kernel_with_peft_prompt_learning_raises(self):
# Prompt-learning methods inject virtual tokens via PeftModel.forward(), which the Liger GRPO loss bypasses.
# The trainer must fail fast to avoid computing the loss on the wrong (truncated) sequence.
model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", dtype="float32")
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
training_args = GRPOConfig(output_dir=self.tmp_dir, use_liger_kernel=True, report_to="none")
with pytest.raises(ValueError, match="prompt-learning"):
GRPOTrainer(
model=model,
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=training_args,
train_dataset=dataset,
peft_config=PromptTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=8),
)

@require_peft
def test_train_peft_model(self):
model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", dtype="float32")
Expand Down
14 changes: 13 additions & 1 deletion trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@

if is_peft_available():
import peft
from peft import LoraConfig, PeftConfig, PeftModel, get_peft_model
from peft import LoraConfig, PeftConfig, PeftModel, PromptLearningConfig, get_peft_model
from peft.tuners.tuners_utils import BaseTunerLayer

if is_liger_kernel_available():
Expand Down Expand Up @@ -656,6 +656,18 @@ def __init__(
"fused GRPO loss reads `lm_head.weight` directly, so the adapter on the head is ignored and never "
"trained. Either remove `'lm_head'` from your `target_modules`, or set `use_liger_kernel=False`."
)
# Prompt-learning methods (PromptTuning, PrefixTuning, P-Tuning) inject virtual tokens via
# `PeftModel.forward()`. The Liger GRPO loss bypasses `PeftModel.forward()` by calling the backbone
# directly, so virtual tokens are never prepended and the loss is computed on the wrong sequence.
# Fail loudly rather than train on a silently corrupted input.
if any(isinstance(cfg, PromptLearningConfig) for cfg in model.peft_config.values()):
raise ValueError(
"`use_liger_kernel=True` is incompatible with prompt-learning PEFT methods (PromptTuning, "
"PrefixTuning, P-Tuning). The Liger GRPO loss bypasses `PeftModel.forward()` by calling the "
"backbone directly, so virtual tokens are never prepended and the loss is computed on the "
"wrong sequence. Use a weight-based adapter such as LoRA instead, or set "
"`use_liger_kernel=False`."
)
self.mask_truncated_completions = args.mask_truncated_completions
self.top_entropy_quantile = args.top_entropy_quantile
if self.use_liger_kernel and self.top_entropy_quantile < 1.0:
Expand Down
Loading