feat(grpo_trainer.py): STARE — Surprisal-guided Token-Level Advantage Reweighting#6167
…ge Reweighting

Implements the STARE policy loss (paper §4.1-4.3) as a new `loss_type` on GRPOTrainer, following the same dispatch convention as `dr_grpo`, `sapo`, `luspo`, `vespo`, etc.

Paper: https://huggingface.co/papers/2606.19236
Official implementation: https://github.com/hp-luo/STARE (verl/trainer/ppo/core_algos.py::compute_policy_loss_stare + STAREController, Apache-2.0)

The STARE objective wraps the GRPO dual-clip surrogate with a per-token PG-loss reweighting factor omega:

- Partition valid response tokens by sign(A) into T+ / T-
- Select top stare_top_p_ratio of each by surprisal s = -log pi_theta(o) to form the entropy-critical sets L+ / L-
- omega = W on L+ (variant O1, default); omega = W on L+ and M on L- (variant C2 dual-sided); 1 otherwise
- Closed-loop gate: weights apply only when batch-mean policy entropy < stare_target_entropy; otherwise revert to 1 (degrades to GRPO)

Paper defaults are wired in (top_p=0.1, W=1.1, M=0.9, target_entropy=0.3). STARE aggregation follows the bnpo branch (matches verl's `token-mean` default for compute_policy_loss_stare).

Co-Authored-By: smellslikeml <9044907+smellslikeml@users.noreply.github.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Cursor Bugbot has reviewed your changes using default effort and found 2 potential issues.
Reviewed by Cursor Bugbot for commit a0827b0.
            selected.reshape(-1)[flat_selected_idx] = True
            return selected

        l_plus = top_surprisal_mask(positive_candidates)
Top-k pooled across batch
Medium Severity
`top_surprisal_mask` counts all masked candidates in the entire `(B, T)` tensor and runs one global `topk`, so the `ceil(count * top_p_ratio)` budget is shared across sequences. Longer or higher-surprisal rows can take every slot, leaving shorter completions with no entropy-critical tokens despite the paper's per-response top-fraction selection.
This is intentional: `top_surprisal_mask` is a verbatim port of the reference implementation (`hp-luo/STARE/verl/trainer/ppo/core_algos.py::compute_stare_token_weights`), which pools candidates across the `(B, T)` batch within each advantage-sign partition. The paper's specification matches:

> "Split response tokens by trajectory-level advantage sign into `T+ = {A_i > 0}` and `T- = {A_i < 0}`. Within each set, rank tokens in descending order of surprisal and select the top-P%..."

Here T+ / T- are batch-wide sets (a trajectory contributes all its tokens to whichever set the sign of its A_i assigns). This is the design: tokens are entropy-critical relative to the batch surprisal distribution, not their own row. Tokens that are "locally" high-surprisal but globally low-surprisal don't contribute much to entropy change, which is the premise of the §4.1 first-order analysis, so they shouldn't be reweighted.
The same batch-pool-by-sign approach is also used by the other two known PyTorch ports of this paper:

- OpenRLHF (smellslikeml/OpenRLHF#4): `kthvalue` on flattened batch surprisals.
- TRL (this PR's earlier `PPOTrainer`-shaped revision): `torch.quantile` on flattened.
Happy to add a `# Batch-pooled by sign (matches reference and paper §4.1)` comment to make the design choice explicit at the call site if useful.
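For readers following the thread, the difference between the two selection rules can be sketched in plain Python. This is an illustration only, not the PR's code: nested lists stand in for `(B, T)` tensors, and the function names `pooled_top_mask` and `per_row_top_mask` are hypothetical. The `ceil(count * top_p_ratio)` budget is taken from the review comment above.

```python
import math


def pooled_top_mask(surprisal, candidates, top_p_ratio):
    """Select the top fraction of candidate tokens, pooled over the whole batch."""
    pool = [
        (s, r, c)
        for r, row in enumerate(surprisal)
        for c, s in enumerate(row)
        if candidates[r][c]
    ]
    selected = [[False] * len(row) for row in surprisal]
    if not pool:
        return selected
    k = math.ceil(len(pool) * top_p_ratio)  # one budget shared by all rows
    # Descending surprisal; ties broken by position so the result is deterministic.
    pool.sort(key=lambda item: (-item[0], item[1], item[2]))
    for _, r, c in pool[:k]:
        selected[r][c] = True
    return selected


def per_row_top_mask(surprisal, candidates, top_p_ratio):
    """Alternative raised in review: apply the same budget rule row by row."""
    return [
        pooled_top_mask([s_row], [c_row], top_p_ratio)[0]
        for s_row, c_row in zip(surprisal, candidates)
    ]


# Row 0 is uniformly more surprising than row 1.
surprisal = [[5.0, 4.0, 3.0, 2.5], [0.4, 0.3, 0.2, 0.1]]
candidates = [[True] * 4, [True] * 4]

pooled = pooled_top_mask(surprisal, candidates, 0.25)
per_row = per_row_top_mask(surprisal, candidates, 0.25)
print(pooled)   # both selected slots land in row 0
print(per_row)  # one selected slot in each row
```

With pooling, the two available slots both go to row 0, which is the behavior the reply above describes as intended; the per-row rule instead gives each row its own slot.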
…l=True`

Address Cursor Bugbot review on PR huggingface#6167: the fused LigerFusedLinearGRPOLoss path bypasses _compute_loss, so STARE's per-token reweighting was silently not applied.

Add an early ValueError in __init__ matching the existing VESPO pre-check pattern, skip the (stare, use_liger_kernel=True) combination in test_train_loss_types, and add a dedicated negative test test_stare_rejects_liger_kernel.

Co-Authored-By: smellslikeml <9044907+smellslikeml@users.noreply.github.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
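The fail-fast pattern this commit describes can be sketched as a standalone check. The helper name `reject_unsupported_fusion` and the error text are hypothetical; the actual check lives in the trainer's `__init__` and may be worded differently.

```python
def reject_unsupported_fusion(loss_type: str, use_liger_kernel: bool) -> None:
    """Raise early instead of silently training without STARE reweighting."""
    if loss_type == "stare" and use_liger_kernel:
        # Hypothetical message: the fused path never reaches the per-token
        # reweighting step, so the combination is refused up front.
        raise ValueError(
            "loss_type='stare' is not supported with use_liger_kernel=True"
        )


# Supported combinations pass through without error.
reject_unsupported_fusion("stare", use_liger_kernel=False)
reject_unsupported_fusion("bnpo", use_liger_kernel=True)

# The unsupported combination is rejected.
try:
    reject_unsupported_fusion("stare", use_liger_kernel=True)
    rejected = False
except ValueError:
    rejected = True
print(rejected)
```

Raising at construction time means a misconfigured run stops before any training step, rather than producing a GRPO run that is mislabeled as STARE.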


What does this PR do?
This PR implements the STARE policy loss as a new `loss_type` on `GRPOTrainer`, following the same dispatch convention as `dr_grpo`, `sapo`, `luspo`, `vespo`, etc.

Paper: https://huggingface.co/papers/2606.19236
Official implementation: https://github.com/hp-luo/STARE, specifically `verl/trainer/ppo/core_algos.py::compute_policy_loss_stare` and `STAREController` (Apache-2.0).

The STARE objective wraps the GRPO dual-clip surrogate with a per-token PG-loss reweighting factor `omega`:

- Partition valid response tokens by `sign(A)` into `T+ = {A > 0}` and `T- = {A < 0}`. Within each set, rank tokens in descending order of surprisal `s = -log π_θ(o)` and select the top `stare_top_p_ratio` to form the entropy-critical sets `L+` (and `L-` for variant C2).
- `omega = W (> 1)` on `L+` (variant O1, default), or `omega = W` on `L+` and `omega = M (< 1)` on `L-` (variant C2, dual-sided regulation). `omega` multiplies the per-token PG loss post-clip.
- Closed-loop gate: the weights apply only when the batch-mean policy entropy satisfies `H_bar < stare_target_entropy`. Otherwise weights revert to 1 and STARE degrades to vanilla GRPO.

Paper defaults are wired in: `stare_top_p_ratio=0.1`, `stare_reweight_w=1.1`, `stare_reweight_m=0.9`, `stare_target_entropy=0.3`, `stare_variant="O1"`.

Aggregation note: STARE joins the `bnpo` aggregation branch. This matches verl's `loss_agg_mode="token-mean"` default for `compute_policy_loss_stare`, which is a per-token mean over valid tokens.

Out of scope for this PR (paper §4.4): adaptive `W`/`M` decay across batches (the optional `STAREController` adaptive path in the reference). This can be added as a follow-up if there is interest; the controller would be stateful across steps and is gated behind `stare_adaptive=False` in the reference.

Tests

- Added `"stare"` to the `loss_type` parametrize in `test_train_loss_types` so STARE runs through the full GRPO training-step path.
- Added `TestComputeStareTokenWeights` with 8 dedicated math-property tests covering: gate-closed behavior, O1 single-sided amplification, C2 dual-sided regulation, mask-exclusion of padded tokens, top-K selection cardinality, and input validation.

A 1-step end-to-end smoke test on `trl-internal-testing/tiny-Qwen2ForCausalLM-2.5` (CPU, paper defaults) completes a full forward + backward + optimizer step, with all 27 parameters updating and STARE metrics surfaced in the training log line.

The closed-loop gate behaves correctly here: the random-init model has batch entropy ≈11.93 >> target 0.3, so the gate stays closed and STARE degrades to vanilla GRPO (paper §4.3 invariant). Gate-open behavior is exercised in the unit tests.
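As a rough illustration of the weighting rule, gate, and token-mean aggregation described in this PR, here is a minimal plain-Python sketch. It is not the PR's `compute_stare_token_weights`: nested lists replace tensors, the names `stare_token_weights` and `token_mean` are hypothetical, and the `ceil` budget and batch-wide pooling are assumptions drawn from the review discussion on this PR.

```python
import math


def stare_token_weights(surprisal, advantages, mask, mean_entropy,
                        top_p_ratio=0.1, w=1.1, m=0.9,
                        target_entropy=0.3, variant="O1"):
    """Return per-token weights omega, shaped like `surprisal` (illustrative)."""
    if variant not in ("O1", "C2"):
        raise ValueError(f"unknown variant: {variant!r}")
    omega = [[1.0] * len(row) for row in surprisal]
    # Closed-loop gate: weights apply only while entropy is below the target.
    if mean_entropy >= target_entropy:
        return omega

    def reweight(sign, factor):
        # Pool valid tokens batch-wide from trajectories with this advantage sign.
        pool = [(surprisal[r][c], r, c)
                for r, adv in enumerate(advantages)
                for c in range(len(surprisal[r]))
                if mask[r][c] and adv * sign > 0]
        if not pool:
            return
        k = math.ceil(len(pool) * top_p_ratio)
        pool.sort(key=lambda item: -item[0])  # descending surprisal
        for _, r, c in pool[:k]:
            omega[r][c] = factor

    reweight(+1, w)          # L+: amplify (both variants)
    if variant == "C2":
        reweight(-1, m)      # L-: dampen (dual-sided variant only)
    return omega


def token_mean(per_token_loss, omega, mask):
    """Weighted per-token loss averaged over the count of valid tokens."""
    total, count = 0.0, 0
    for loss_row, w_row, m_row in zip(per_token_loss, omega, mask):
        for loss, weight, valid in zip(loss_row, w_row, m_row):
            if valid:
                total += weight * loss
                count += 1
    return total / max(count, 1)


surprisal = [[2.0, 0.5, 1.0, 0.1], [3.0, 0.2, 0.4, 0.0]]
advantages = [1.0, -1.0]                 # one trajectory-level advantage per row
mask = [[1, 1, 1, 1], [1, 1, 1, 0]]      # last token of row 1 is padding

open_o1 = stare_token_weights(surprisal, advantages, mask, 0.2, top_p_ratio=0.5)
open_c2 = stare_token_weights(surprisal, advantages, mask, 0.2,
                              top_p_ratio=0.5, variant="C2")
closed = stare_token_weights(surprisal, advantages, mask, 11.93, top_p_ratio=0.5)
loss = token_mean([[1.0] * 4, [1.0] * 4], open_o1, mask)
```

With the gate closed (entropy 11.93, as in the smoke test above) every weight stays at 1, so the aggregated loss is the unweighted token mean. Note that `token_mean` divides by the number of valid tokens rather than by the sum of the weights, which is one reading of "per-token mean over valid tokens".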
Code conventions
- `compute_stare_token_weights` is a `@staticmethod` `@torch.no_grad()` helper colocated with `get_gamma_weights` (VESPO precedent).
- `paper_index.md` entry added (matches VESPO/SAPO/GMPO format) including the STARE objective in LaTeX.
- New fields on `GRPOConfig` with full docstrings and `field(..., metadata={"help": ...})` patterns.
- `ruff check` and `ruff format` clean on all changed files.

Before submitting
AI writing disclosure
We welcome the use of AI tools to help with contributions. For transparency and to help us improve our review process, please indicate the level of AI involvement in this PR.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
Note
Medium Risk
Changes core GRPO loss dispatch and training dynamics for users who opt in via `loss_type="stare"`; mitigated by validation, unit tests, and explicit rejection of unsupported Liger fusion.

Overview

Adds STARE as a new `loss_type="stare"` on `GRPOTrainer`, alongside existing paper losses (VESPO, SAPO, etc.).

`GRPOConfig` gains `stare_variant`, `stare_top_p_ratio`, `stare_reweight_w`, `stare_reweight_m`, and `stare_target_entropy` (paper defaults). `compute_stare_token_weights` picks high-surprisal tokens within positive/negative advantage partitions, applies multiplicative PG-loss weights when batch mean entropy is below the target (closed-loop gate), and supports O1 vs C2 variants.

`_compute_loss` builds the standard GRPO dual-clip surrogate, then multiplies by STARE weights; aggregation follows bnpo (token-mean). STARE metrics are logged; `use_liger_kernel=True` is rejected so the objective is not silently skipped.

Docs add a paper_index entry; tests cover weight math, training smoke with `"stare"`, and the Liger incompatibility.

Reviewed by Cursor Bugbot for commit 22337c6.