From a0827b0a1d54279019e8b998b74f332e4a5b16ac Mon Sep 17 00:00:00 2001 From: "remyx-ai[bot]" <289541483+remyx-ai[bot]@users.noreply.github.com> Date: Wed, 24 Jun 2026 09:13:56 -0700 Subject: [PATCH 1/2] =?UTF-8?q?feat(`grpo=5Ftrainer.py`):=20STARE=20?= =?UTF-8?q?=E2=80=94=20Surprisal-guided=20Token-Level=20Advantage=20Reweig?= =?UTF-8?q?hting?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements the STARE policy loss (paper ยง4.1-4.3) as a new `loss_type` on GRPOTrainer, following the same dispatch convention as `dr_grpo`, `sapo`, `luspo`, `vespo`, etc. Paper: https://huggingface.co/papers/2606.19236 Official implementation: https://github.com/hp-luo/STARE (verl/trainer/ppo/core_algos.py::compute_policy_loss_stare + STAREController, Apache-2.0) The STARE objective wraps the GRPO dual-clip surrogate with a per-token PG-loss reweighting factor omega: - Partition valid response tokens by sign(A) into T+ / T- - Select top stare_top_p_ratio of each by surprisal s = -log pi_theta(o) to form the entropy-critical sets L+ / L- - omega = W on L+ (variant O1, default); omega = W on L+ and M on L- (variant C2 dual-sided); 1 otherwise - Closed-loop gate: weights apply only when batch-mean policy entropy < stare_target_entropy; otherwise revert to 1 (degrades to GRPO) Paper defaults are wired in (top_p=0.1, W=1.1, M=0.9, target_entropy=0.3). STARE aggregation follows the bnpo branch (matches verl's `token-mean` default for compute_policy_loss_stare). Co-Authored-By: smellslikeml <9044907+smellslikeml@users.noreply.github.com> Co-Authored-By: Claude Opus 4.7 (1M context) --- docs/source/paper_index.md | 34 +++++++++ tests/test_grpo_trainer.py | 146 +++++++++++++++++++++++++++++++++++- trl/trainer/grpo_config.py | 63 ++++++++++++++++ trl/trainer/grpo_trainer.py | 130 +++++++++++++++++++++++++++++++- 4 files changed, 370 insertions(+), 3 deletions(-) diff --git a/docs/source/paper_index.md b/docs/source/paper_index.md index 533bb98a8b4..b13c8431a6a 100644 --- a/docs/source/paper_index.md +++ b/docs/source/paper_index.md @@ -692,6 +692,40 @@ training_args = GRPOConfig( ``` +### STARE: Surprisal-Guided Token-Level Advantage Reweighting for Policy Entropy Stability + +**๐Ÿ“œ Paper**: https://huggingface.co/papers/2606.19236 + +STARE addresses policy entropy collapse in PPO/GRPO-style RL on language models. A first-order analysis of per-token entropy dynamics shows that the change in entropy is proportional to the product of a token's advantage and an entropy-sensitivity term that grows with the token's surprisal \\( s_{i,t} = -\log \pi_\theta(o_{i,t}) \\) โ€” so updates on high-surprisal tokens dominate how fast the policy loses (or keeps) exploration. STARE identifies the *entropy-critical token subset* via a top-fraction-by-surprisal selection within each advantage-sign partition and reweights the per-token PG loss of those tokens, gated by a closed-loop target-entropy signal. + +The STARE objective wraps the standard GRPO dual-clip surrogate with a token-level reweighting factor \\( \omega_{i,t} \\): + +$$ +\mathcal{J}_{\text{STARE}}(\theta) = \mathbb{E}_{q, \{o_i\}} \left[ \frac{1}{\sum_i |o_i|} \sum_{i=1}^{G} \sum_{t=1}^{|o_i|} \omega_{i,t} \cdot \min\left( w_{i,t}(\theta) \hat{A}_{i,t},\ \operatorname{clip}(w_{i,t}(\theta), 1-\epsilon, 1+\epsilon)\, \hat{A}_{i,t} \right) \right] +$$ + +with \\( w_{i,t}(\theta) = \pi_\theta(o_{i,t} \mid q, o_{i, 1 & \text{if } (i,t) \in L^+ \text{ (positive-advantage critical tokens)} \\ M < 1 & \text{if } (i,t) \in L^- \text{ and variant } = C2 \\ 1 & \text{otherwise} \end{cases} +$$ + +active only when the closed-loop gate is open: \\( \mathbb{1}[\bar{H}_k < H_{\text{tgt}}] \\), where \\( \bar{H}_k \\) is the batch-mean policy entropy and \\( H_{\text{tgt}} \\) is the target floor. When the gate stays closed, all weights revert to 1 and STARE degrades to vanilla GRPO. + +```python +from trl import GRPOConfig + +training_args = GRPOConfig( + loss_type="stare", + stare_variant="O1", # "O1" (one-sided amplification, default) or "C2" (dual-sided regulation) + stare_top_p_ratio=0.1, # paper Section 4.1 default + stare_reweight_w=1.1, # multiplicative weight for L+ (positive-advantage critical tokens) + stare_reweight_m=0.9, # multiplicative weight for L- (only active when variant="C2") + stare_target_entropy=0.3, # closed-loop gate floor (paper Section 4.3) +) +``` + + ### Rethinking the Trust Region in LLM Reinforcement Learning **๐Ÿ“œ Paper**: https://huggingface.co/papers/2602.04879 diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index 98d61d581fa..8e6798e0599 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -159,6 +159,150 @@ def test_compute_entropy_all_masked(self): torch.testing.assert_close(entropy_mask, expected_mask) +class TestComputeStareTokenWeights(TrlTestCase): + """Math properties of STARE per-token PG-loss reweighting (paper Sections 4.1-4.3).""" + + def _logps(self, probs): + return torch.log(torch.tensor(probs, dtype=torch.float32)) + + def _low_entropy(self, shape): + # Far below the default target_entropy=0.3 โ†’ closed-loop gate opens. + return torch.full(shape, 0.01) + + def _high_entropy(self, shape): + # Above target_entropy=0.3 โ†’ gate stays closed. + return torch.full(shape, 0.5) + + def test_gate_closed_returns_unit_weights(self): + # When batch-mean entropy >= target, all weights revert to 1 and STARE degrades to vanilla GRPO. + logps = self._logps([[-0.1, -0.2, -3.0, -4.0]]) + advantages = torch.tensor([[1.0, 1.0, 1.0, 1.0]]) + mask = torch.ones_like(advantages) + weights, stats = GRPOTrainer.compute_stare_token_weights( + per_token_logps=logps, + advantages=advantages, + entropies=self._high_entropy(logps.shape), + mask=mask, + target_entropy=0.3, + ) + torch.testing.assert_close(weights, torch.ones_like(weights)) + assert stats["stare/gate_on"] == 0.0 + + def test_variant_o1_reweights_only_positive_critical_tokens(self): + # Two positive-advantage tokens; the higher-surprisal one is in L+ and gets reweight_w. + logps = self._logps([[-0.1, -3.0]]) # surprisal: [0.1, 3.0] โ†’ token 1 is critical + advantages = torch.tensor([[1.0, 1.0]]) + mask = torch.ones_like(advantages) + weights, stats = GRPOTrainer.compute_stare_token_weights( + per_token_logps=logps, + advantages=advantages, + entropies=self._low_entropy(logps.shape), + mask=mask, + variant="O1", + top_p_ratio=0.5, # 50% โ†’ ceil(2*0.5)=1 critical token per partition + reweight_w=1.5, + reweight_m=0.9, + ) + torch.testing.assert_close(weights, torch.tensor([[1.0, 1.5]])) + assert stats["stare/gate_on"] == 1.0 + + def test_variant_o1_ignores_negative_advantage_tokens(self): + # O1 should leave negative-advantage tokens at weight 1 regardless of surprisal. + logps = self._logps([[-3.0, -3.0]]) + advantages = torch.tensor([[1.0, -1.0]]) # one positive, one negative + mask = torch.ones_like(advantages) + weights, _ = GRPOTrainer.compute_stare_token_weights( + per_token_logps=logps, + advantages=advantages, + entropies=self._low_entropy(logps.shape), + mask=mask, + variant="O1", + top_p_ratio=1.0, + reweight_w=1.5, + reweight_m=0.5, + ) + # Positive-advantage token reweighted by W; negative-advantage token untouched. + torch.testing.assert_close(weights, torch.tensor([[1.5, 1.0]])) + + def test_variant_c2_reweights_both_partitions(self): + # C2 also damps negative-advantage critical tokens. + logps = self._logps([[-3.0, -3.0]]) + advantages = torch.tensor([[1.0, -1.0]]) + mask = torch.ones_like(advantages) + weights, _ = GRPOTrainer.compute_stare_token_weights( + per_token_logps=logps, + advantages=advantages, + entropies=self._low_entropy(logps.shape), + mask=mask, + variant="C2", + top_p_ratio=1.0, + reweight_w=1.5, + reweight_m=0.5, + ) + torch.testing.assert_close(weights, torch.tensor([[1.5, 0.5]])) + + def test_mask_excludes_padded_tokens(self): + # Padded positions must be excluded from candidate sets even when they have high surprisal. + logps = self._logps([[-0.1, -3.0, -5.0]]) # token 2 has highest surprisal but is padded + advantages = torch.tensor([[1.0, 1.0, 1.0]]) + mask = torch.tensor([[1.0, 1.0, 0.0]]) # last token padded + weights, stats = GRPOTrainer.compute_stare_token_weights( + per_token_logps=logps, + advantages=advantages, + entropies=self._low_entropy(logps.shape) * mask, + mask=mask, + variant="O1", + top_p_ratio=0.5, # ceil(2*0.5)=1 critical token out of 2 valid candidates + reweight_w=1.5, + ) + # Token 1 (surprisal 3.0) is the highest-surprisal *valid* candidate. + torch.testing.assert_close(weights, torch.tensor([[1.0, 1.5, 1.0]])) + # Padded token contributes 0 to the reweight-ratio denominator. + assert stats["stare/positive_reweight_ratio"] == 0.5 + + def test_top_p_ratio_selects_correct_count(self): + # 4 positive-advantage tokens, top_p_ratio=0.5 โ†’ ceil(4*0.5)=2 critical tokens. + logps = self._logps([[-0.1, -0.2, -3.0, -4.0]]) # tokens 2, 3 have highest surprisal + advantages = torch.tensor([[1.0, 1.0, 1.0, 1.0]]) + mask = torch.ones_like(advantages) + weights, _ = GRPOTrainer.compute_stare_token_weights( + per_token_logps=logps, + advantages=advantages, + entropies=self._low_entropy(logps.shape), + mask=mask, + variant="O1", + top_p_ratio=0.5, + reweight_w=1.5, + ) + torch.testing.assert_close(weights, torch.tensor([[1.0, 1.0, 1.5, 1.5]])) + + def test_rejects_invalid_variant(self): + logps = self._logps([[-0.5]]) + advantages = torch.tensor([[1.0]]) + mask = torch.ones_like(advantages) + with pytest.raises(ValueError, match="Unsupported STARE variant"): + GRPOTrainer.compute_stare_token_weights( + per_token_logps=logps, + advantages=advantages, + entropies=self._low_entropy(logps.shape), + mask=mask, + variant="X1", + ) + + def test_rejects_invalid_top_p_ratio(self): + logps = self._logps([[-0.5]]) + advantages = torch.tensor([[1.0]]) + mask = torch.ones_like(advantages) + with pytest.raises(ValueError, match="stare_top_p_ratio"): + GRPOTrainer.compute_stare_token_weights( + per_token_logps=logps, + advantages=advantages, + entropies=self._low_entropy(logps.shape), + mask=mask, + top_p_ratio=1.5, + ) + + class TestGRPORolloutDispatch: def _make_trainer(self): trainer = object.__new__(GRPOTrainer) @@ -393,7 +537,7 @@ def reward_func(completions, **kwargs): assert type(trainer.model).__name__ == "RemoteForCausalLM" @pytest.mark.parametrize("use_liger_kernel", [False, pytest.param(True, marks=require_liger_kernel)]) - @pytest.mark.parametrize("loss_type", ["bnpo", "dr_grpo", "dapo", "cispo", "sapo", "luspo", "vespo"]) + @pytest.mark.parametrize("loss_type", ["bnpo", "dr_grpo", "dapo", "cispo", "sapo", "luspo", "vespo", "stare"]) def test_train_loss_types(self, loss_type, use_liger_kernel): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index 5736b6c0ddc..cdd6d90b47f 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -212,6 +212,23 @@ class GRPOConfig(_BaseConfig): lambda parameter for negative advantages, it is the exponential decay factor in the VESPO loss. Controls how aggressively we down-weight samples with high importance weights (when the importance sampling ratio > 1). + stare_variant (`str`, *optional*, defaults to `"O1"`): + STARE token-reweighting variant. `"O1"` (default, paper Section 4.2) amplifies positive-advantage + entropy-critical tokens only. `"C2"` additionally damps negative-advantage entropy-critical tokens. + Introduced in the [STARE paper](https://huggingface.co/papers/2606.19236). + stare_top_p_ratio (`float`, *optional*, defaults to `0.1`): + Fraction of highest-surprisal tokens (within each advantage-sign partition) treated as entropy-critical + and reweighted by STARE. Paper Section 4.1 default is `0.1`. + stare_reweight_w (`float`, *optional*, defaults to `1.1`): + Multiplicative weight `W > 1` applied to the per-token PG loss of positive-advantage entropy-critical + tokens (the set L+). Paper Section 4.2 default is `1.1`. + stare_reweight_m (`float`, *optional*, defaults to `0.9`): + Multiplicative weight `M < 1` applied to the per-token PG loss of negative-advantage entropy-critical + tokens (the set L-). Only active when `stare_variant="C2"`. Paper Section 4.2 default is `0.9`. + stare_target_entropy (`float`, *optional*, defaults to `0.3`): + Target policy entropy floor for STARE's closed-loop gate (paper Section 4.3). The gate opens + (reweighting is active) when the batch-mean policy entropy falls below this value; otherwise weights + revert to 1 and STARE degrades to vanilla GRPO. Paper default is `0.3`. importance_sampling_level (`str`, *optional*, defaults to `"token"`): Controls whether importance sampling ratios are computed at the `"token"` or `"sequence"` level. `"token"` keeps the raw per-token log-probability ratios (one weight per token). `"sequence"` averages the @@ -272,6 +289,10 @@ class GRPOConfig(_BaseConfig): - `"vespo"`: Variational Sequence-Level Soft Policy Optimization. Replaces hard clipping with a smooth, asymmetric Gamma weighting function applied directly to sequence-level importance weights. Introduced in the [VESPO paper](https://huggingface.co/papers/2602.10693). + - `"stare"`: Surprisal-Guided Token-Level Advantage Reweighting for Policy Entropy Stability. Wraps the + GRPO dual-clip surrogate and reweights the per-token PG loss of entropy-critical tokens (the + highest-surprisal tokens within each advantage-sign partition), gated by a closed-loop target-entropy + signal. Introduced in the [STARE paper](https://huggingface.co/papers/2606.19236). mask_truncated_completions (`bool`, *optional*, defaults to `False`): When enabled, truncated completions are excluded from the loss calculation, preventing them from being incorrectly penalized and introducing noise during training. According to the @@ -716,6 +737,44 @@ class GRPOConfig(_BaseConfig): "sampling ratio > 1)." }, ) + stare_variant: str = field( + default="O1", + metadata={ + "help": "STARE token-reweighting variant. `'O1'` (default, paper Section 4.2) amplifies positive-advantage " + "entropy-critical tokens only. `'C2'` additionally damps negative-advantage entropy-critical tokens. " + "Introduced in the [STARE paper](https://huggingface.co/papers/2606.19236)." + }, + ) + stare_top_p_ratio: float = field( + default=0.1, + metadata={ + "help": "Fraction of highest-surprisal tokens (within each advantage-sign partition) treated as " + "entropy-critical and reweighted by STARE. Paper Section 4.1 default is `0.1`." + }, + ) + stare_reweight_w: float = field( + default=1.1, + metadata={ + "help": "Multiplicative weight `W > 1` applied to the per-token PG loss of positive-advantage " + "entropy-critical tokens (the set L+). Paper Section 4.2 default is `1.1`." + }, + ) + stare_reweight_m: float = field( + default=0.9, + metadata={ + "help": "Multiplicative weight `M < 1` applied to the per-token PG loss of negative-advantage " + "entropy-critical tokens (the set L-). Only active when `stare_variant='C2'`. Paper Section 4.2 default " + "is `0.9`." + }, + ) + stare_target_entropy: float = field( + default=0.3, + metadata={ + "help": "Target policy entropy floor for STARE's closed-loop gate (paper Section 4.3). The gate opens " + "(reweighting is active) when the batch-mean policy entropy falls below this value; otherwise weights " + "revert to 1 and STARE degrades to vanilla GRPO. Paper default is `0.3`." + }, + ) importance_sampling_level: str = field( default="token", metadata={ @@ -790,6 +849,10 @@ class GRPOConfig(_BaseConfig): "'vespo': Variational Sequence-Level Soft Policy Optimization. Replaces hard clipping with a smooth, " "asymmetric Gamma weighting function applied directly to sequence-level importance weights. Introduced in " "the [VESPO paper](https://huggingface.co/papers/2602.10693)." + "'stare': Surprisal-Guided Token-Level Advantage Reweighting for Policy Entropy Stability. Wraps the " + "GRPO dual-clip surrogate and reweights the per-token PG loss of entropy-critical tokens (the " + "highest-surprisal tokens within each advantage-sign partition), gated by a closed-loop target-entropy " + "signal. Introduced in the [STARE paper](https://huggingface.co/papers/2606.19236)." }, ) mask_truncated_completions: bool = field( diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 806d6674c57..e2e920c8ff7 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -2592,6 +2592,106 @@ def get_gamma_weights( return phi_seq # (B, 1) + @staticmethod + @torch.no_grad() + def compute_stare_token_weights( + per_token_logps: torch.Tensor, + advantages: torch.Tensor, + entropies: torch.Tensor, + mask: torch.Tensor, + variant: str = "O1", + top_p_ratio: float = 0.1, + reweight_w: float = 1.1, + reweight_m: float = 0.9, + target_entropy: float = 0.3, + ) -> tuple[torch.Tensor, dict[str, float]]: + """ + Compute STARE per-token PG-loss reweighting factors (paper Sections 4.1-4.3). + + Splits valid response tokens by trajectory-level advantage sign into T+ = {A > 0} and T- = {A < 0}. Within + each set, ranks tokens in descending order of surprisal ``s = -log pi_theta(o)`` and selects the top + ``top_p_ratio`` to form the entropy-critical sets L+ (and L- for variant C2). The returned weights ``omega`` + multiply the per-token PG loss: + + - Variant ``"O1"`` (default, one-sided amplification): omega = W on L+, else 1. + - Variant ``"C2"`` (dual-sided regulation): omega = W on L+, M on L-, else 1. + + Gated by the closed-loop target-entropy signal: when the batch-mean policy entropy is at or above + ``target_entropy``, the gate stays closed and all weights revert to 1 (STARE degrades to vanilla GRPO). + + Args: + per_token_logps: ``(B, T)`` current-policy token log-probs. Surprisal is ``-per_token_logps``. + advantages: ``(B, T)`` per-token advantages (the shared trajectory-level GRPO advantage broadcast over + tokens). + entropies: ``(B, T)`` per-token entropy from the actor forward pass. Used to compute the batch-mean + entropy ``H_bar`` that drives the closed-loop gate. + mask: ``(B, T)`` response-token mask (1 for response tokens). + variant: ``"O1"`` (default) or ``"C2"`` per paper Section 4.2. + top_p_ratio: Fraction of entropy-critical tokens to reweight within each partition. + reweight_w: Multiplicative weight ``W > 1`` for L+ (positive-advantage critical tokens). + reweight_m: Multiplicative weight ``M < 1`` for L- (negative-advantage critical tokens; ``"C2"`` only). + target_entropy: Entropy floor for the closed-loop gate (paper Section 4.3). + + Returns: + weights: ``(B, T)`` reweighting factors (default 1.0). Scales the per-token PG loss in-place. + stats: Scalar metrics describing the selection and the applied weights. + """ + if variant not in {"O1", "C2"}: + raise ValueError(f"Unsupported STARE variant: {variant!r}. Expected 'O1' or 'C2'.") + if not 0.0 <= top_p_ratio <= 1.0: + raise ValueError(f"Invalid stare_top_p_ratio: {top_p_ratio}. Expected a value in [0, 1].") + + weights = torch.ones_like(advantages) + valid = mask > 0 + valid_count = valid.float().sum().clamp_min(1.0) + + # Closed-loop gate: batch-mean policy entropy vs target (paper Section 4.3). + batch_mean_entropy = (entropies.detach().float() * valid.float()).sum() / valid_count + gate_on = bool(batch_mean_entropy.item() < target_entropy) + + positive_candidates = valid & (advantages > 0) + negative_candidates = valid & (advantages < 0) + + # Surprisal s = -log pi_theta(o); detached because it only drives token selection. + surprisal = -per_token_logps.detach() + + def top_surprisal_mask(candidate_mask: torch.Tensor) -> torch.Tensor: + """Select the top-``top_p_ratio`` highest-surprisal tokens within ``candidate_mask``.""" + candidate_count = int(candidate_mask.sum().item()) + selected = torch.zeros_like(candidate_mask, dtype=torch.bool) + if candidate_count == 0 or top_p_ratio <= 0.0: + return selected + # ceil(candidate_count * top_p_ratio), at least 1 token. + k = max(1, math.ceil(candidate_count * top_p_ratio)) + k = min(k, candidate_count) + candidate_values = surprisal[candidate_mask] + _, topk_local_idx = torch.topk(candidate_values, k=k, largest=True) + flat_candidate_idx = candidate_mask.reshape(-1).nonzero(as_tuple=False).squeeze(-1) + flat_selected_idx = flat_candidate_idx[topk_local_idx] + selected.reshape(-1)[flat_selected_idx] = True + return selected + + l_plus = top_surprisal_mask(positive_candidates) + l_minus = ( + top_surprisal_mask(negative_candidates) if variant == "C2" else torch.zeros_like(valid, dtype=torch.bool) + ) + + if gate_on: + weights = torch.where(l_plus, torch.full_like(weights, reweight_w), weights) + if variant == "C2": + weights = torch.where(l_minus, torch.full_like(weights, reweight_m), weights) + + stats = { + "stare/gate_on": float(gate_on), + "stare/batch_entropy": float(batch_mean_entropy.item()), + "stare/positive_reweight_ratio": float((l_plus.float() * valid.float()).sum().item() / valid_count.item()), + "stare/negative_reweight_ratio": float( + (l_minus.float() * valid.float()).sum().item() / valid_count.item() + ), + "stare/mean_weight": float((weights * valid.float()).sum().item() / valid_count.item()), + } + return weights, stats + def _compute_loss(self, model, inputs): # Compute the per-token log probabilities for the model prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] @@ -2707,6 +2807,28 @@ def _compute_loss(self, model, inputs): lambda_neg=self.args.vespo_lambda_neg, ) per_token_loss = -phi_seq * advantages * per_token_logps + elif self.loss_type == "stare": + # Vanilla GRPO dual-clip surrogate (paper Section 4.2 / Algorithm 1 base). + coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) + if self.args.delta is not None: + coef_1 = torch.clamp(coef_1, max=self.args.delta) + per_token_loss1 = coef_1 * advantages + per_token_loss2 = coef_2 * advantages + per_token_loss = -torch.min(per_token_loss1, per_token_loss2) + + # STARE token-level PG-loss reweighting (paper Sections 4.1-4.3). + stare_weights, stare_stats = self.compute_stare_token_weights( + per_token_logps=per_token_logps, + advantages=advantages, + entropies=entropies, + mask=mask, + variant=self.args.stare_variant, + top_p_ratio=self.args.stare_top_p_ratio, + reweight_w=self.args.stare_reweight_w, + reweight_m=self.args.stare_reweight_m, + target_entropy=self.args.stare_target_entropy, + ) + per_token_loss = per_token_loss * stare_weights else: raise ValueError(f"Unknown loss type: {self.loss_type}") @@ -2727,7 +2849,7 @@ def _compute_loss(self, model, inputs): loss = ((per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).mean() normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 # no accum in eval loss = loss / normalizer - elif self.loss_type == "bnpo": + elif self.loss_type in ["bnpo", "stare"]: loss = (per_token_loss * mask).sum() / mask.sum().clamp(min=1.0) normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 # no accum in eval loss = loss / normalizer @@ -2768,7 +2890,7 @@ def masked_batch_mean(x): mean_entropy = masked_batch_mean(entropies) self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item()) - if self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo", "luspo"]: + if self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo", "luspo", "stare"]: # Compute the clipped probability ratios is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages < 0) is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages > 0) @@ -2795,6 +2917,10 @@ def masked_batch_mean(x): gathered_phi_seq = self.accelerator.gather(phi_seq) self._metrics[mode]["vespo/phi_seq_mean"].append(gathered_phi_seq.nanmean().item()) + if self.loss_type == "stare": + for stat_name, stat_value in stare_stats.items(): + self._metrics[mode][stat_name].append(stat_value) + return loss # During eval, Trainer calls prediction_step. If no labels are present in the inputs, it only runs forward and From 22337c6153449fec1064e2eb23c1fc1c33842490 Mon Sep 17 00:00:00 2001 From: "remyx-ai[bot]" <289541483+remyx-ai[bot]@users.noreply.github.com> Date: Wed, 24 Jun 2026 09:48:45 -0700 Subject: [PATCH 2/2] fix(`grpo_trainer.py`): reject `loss_type='stare'` + `use_liger_kernel=True` Address Cursor Bugbot review on PR #6167: the fused LigerFusedLinearGRPOLoss path bypasses _compute_loss, so STARE's per-token reweighting was silently not applied. Add an early ValueError in __init__ matching the existing VESPO pre-check pattern, skip the (stare, use_liger_kernel=True) combination in test_train_loss_types, and add a dedicated negative test test_stare_rejects_liger_kernel. Co-Authored-By: smellslikeml <9044907+smellslikeml@users.noreply.github.com> Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/test_grpo_trainer.py | 22 ++++++++++++++++++++++ trl/trainer/grpo_trainer.py | 8 ++++++++ 2 files changed, 30 insertions(+) diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index 8e6798e0599..0035a351f19 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -539,6 +539,10 @@ def reward_func(completions, **kwargs): @pytest.mark.parametrize("use_liger_kernel", [False, pytest.param(True, marks=require_liger_kernel)]) @pytest.mark.parametrize("loss_type", ["bnpo", "dr_grpo", "dapo", "cispo", "sapo", "luspo", "vespo", "stare"]) def test_train_loss_types(self, loss_type, use_liger_kernel): + if loss_type == "stare" and use_liger_kernel: + pytest.skip( + "STARE is not yet supported with use_liger_kernel=True (validated by test_stare_rejects_liger_kernel)" + ) dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") training_args = GRPOConfig( @@ -571,6 +575,24 @@ def test_train_loss_types(self, loss_type, use_liger_kernel): new_param = trainer.model.get_parameter(n) assert not torch.equal(param, new_param), f"Parameter {n} has not changed." + def test_stare_rejects_liger_kernel(self): + # STARE's per-token reweighting is not implemented in the fused Liger GRPO loss, so the combination must + # raise early rather than silently bypass the STARE objective. + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + with pytest.raises(ValueError, match="STARE loss is not yet supported with `use_liger_kernel=True`"): + GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=GRPOConfig( + output_dir=self.tmp_dir, + loss_type="stare", + use_liger_kernel=True, + report_to="none", + ), + train_dataset=dataset, + ) + def test_train_with_eval(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index e2e920c8ff7..d158dee8929 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -712,6 +712,14 @@ def __init__( f"'token_mask'. Got: {self.vllm_importance_sampling_mode}." ) + if self.loss_type == "stare" and args.use_liger_kernel: + raise ValueError( + "STARE loss is not yet supported with `use_liger_kernel=True`. The fused Liger GRPO loss does not " + "currently apply STARE's per-token reweighting, so enabling both would silently bypass the STARE " + "objective. Set `use_liger_kernel=False` (the default) to use `loss_type='stare'`, or open a " + "follow-up issue if Liger-fused STARE is needed." + ) + # Multi-step self.num_iterations = args.num_iterations # = ๐œ‡ in the GRPO paper self.epsilon_low = args.epsilon