Skip to content

feat(grpo_trainer.py): STARE — Surprisal-guided Token-Level Advantage Reweighting#6167

Open
smellslikeml wants to merge 2 commits into
huggingface:mainfrom
smellslikeml:stare-surprisal-guided-token-level-advantage-reweighting-for
Open

feat(grpo_trainer.py): STARE — Surprisal-guided Token-Level Advantage Reweighting#6167
smellslikeml wants to merge 2 commits into
huggingface:mainfrom
smellslikeml:stare-surprisal-guided-token-level-advantage-reweighting-for

Conversation

@smellslikeml

@smellslikeml smellslikeml commented Jun 24, 2026

Copy link
Copy Markdown

What does this PR do?

This PR implements the STARE policy loss as a new loss_type on GRPOTrainer, following the same dispatch convention as dr_grpo, sapo, luspo, vespo, etc.

Paper: https://huggingface.co/papers/2606.19236
Official implementation: https://github.com/hp-luo/STARE — specifically verl/trainer/ppo/core_algos.py::compute_policy_loss_stare and STAREController (Apache-2.0).

The STARE objective wraps the GRPO dual-clip surrogate with a per-token PG-loss reweighting factor omega:

  • Token partitioning (paper §4.1): Split valid response tokens by sign(A) into T+ = {A > 0} and T- = {A < 0}. Within each set, rank tokens in descending order of surprisal s = -log π_θ(o) and select the top stare_top_p_ratio to form the entropy-critical sets L+ (and L- for variant C2).
  • Advantage reweighting (paper §4.2): omega = W (> 1) on L+ (variant O1, default) or omega = W on L+ and omega = M (< 1) on L- (variant C2 dual-sided regulation). Multiplies the per-token PG loss post-clip.
  • Closed-loop gate (paper §4.3): The reweighting is active only when batch-mean policy entropy H_bar < stare_target_entropy. Otherwise weights revert to 1 and STARE degrades to vanilla GRPO.

Paper defaults are wired in: stare_top_p_ratio=0.1, stare_reweight_w=1.1, stare_reweight_m=0.9, stare_target_entropy=0.3, stare_variant="O1".

Aggregation note: STARE joins the bnpo aggregation branch — matches verl's loss_agg_mode="token-mean" default for compute_policy_loss_stare, which is per-token mean over valid tokens.

Out of scope for this PR (paper §4.4): Adaptive W/M decay across batches (the optional STAREController adaptive path in the reference). Can be added as a follow-up if there is interest — the controller would be stateful across steps and is gated behind stare_adaptive=False in the reference.

Tests

  • Added "stare" to the loss_type parametrize in test_train_loss_types so STARE runs through the full GRPO training-step path.
  • Added TestComputeStareTokenWeights with 8 dedicated math-property tests covering: gate-closed behavior, O1 single-sided amplification, C2 dual-sided regulation, mask-exclusion of padded tokens, top-K selection cardinality, and input validation.
$ pytest tests/test_grpo_trainer.py::TestComputeStareTokenWeights -q
........                                                                  [100%]
8 passed in 0.4s

A 1-step end-to-end smoke test on trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 (CPU, paper defaults) completes a full forward + backward + optimizer step, with all 27 parameters updating and STARE metrics surfaced in the training log line:

train_loss: 5.96e-08   grad_norm: 0.195   params_moved: 27/27
stare/gate_on: 0       stare/batch_entropy: 11.93
stare/positive_reweight_ratio: 0.036   stare/mean_weight: 1.0

The closed-loop gate behaves correctly here: random-init model has batch entropy ≈11.93 >> target 0.3, so the gate stays closed and STARE degrades to vanilla GRPO (paper §4.3 invariant). Gate-open behavior is exercised in the unit tests.

Code conventions

  • compute_stare_token_weights is a @staticmethod @torch.no_grad() helper colocated with get_gamma_weights (VESPO precedent).
  • paper_index.md entry added (matches VESPO/SAPO/GMPO format) including the STARE objective in LaTeX.
  • Configuration parameters added to GRPOConfig with full docstrings and field(..., metadata={"help": ...}) patterns.
  • ruff check and ruff format clean on all changed files.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

AI writing disclosure

We welcome the use of AI tools to help with contributions. For transparency and to help us improve our review process, please indicate the level of AI involvement in this PR.

  • No AI usage: the PR was written entirely by a human.
  • AI-assisted: some parts were suggested or improved by AI, but the PR was written and reviewed by a human.
  • AI-generated: the PR was mostly or fully generated by an AI tool.

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.


Note

Medium Risk
Changes core GRPO loss dispatch and training dynamics for users who opt in via loss_type="stare"; mitigated by validation, unit tests, and explicit rejection of unsupported Liger fusion.

Overview
Adds STARE as a new loss_type="stare" on GRPOTrainer, alongside existing paper losses (VESPO, SAPO, etc.).

GRPOConfig gains stare_variant, stare_top_p_ratio, stare_reweight_w, stare_reweight_m, and stare_target_entropy (paper defaults). compute_stare_token_weights picks high-surprisal tokens within positive/negative advantage partitions, applies multiplicative PG-loss weights when batch mean entropy is below the target (closed-loop gate), and supports O1 vs C2 variants.

_compute_loss builds the standard GRPO dual-clip surrogate, then multiplies by STARE weights; aggregation follows bnpo (token-mean). STARE metrics are logged; use_liger_kernel=True is rejected so the objective is not silently skipped.

Docs add a paper_index entry; tests cover weight math, training smoke with "stare", and the Liger incompatibility.

Reviewed by Cursor Bugbot for commit 22337c6. Bugbot is set up for automated code reviews on this repo. Configure here.

…ge Reweighting

Implements the STARE policy loss (paper §4.1-4.3) as a new `loss_type` on
GRPOTrainer, following the same dispatch convention as `dr_grpo`, `sapo`,
`luspo`, `vespo`, etc.

Paper: https://huggingface.co/papers/2606.19236
Official implementation: https://github.com/hp-luo/STARE
  (verl/trainer/ppo/core_algos.py::compute_policy_loss_stare + STAREController,
  Apache-2.0)

The STARE objective wraps the GRPO dual-clip surrogate with a per-token
PG-loss reweighting factor omega:
  - Partition valid response tokens by sign(A) into T+ / T-
  - Select top stare_top_p_ratio of each by surprisal s = -log pi_theta(o)
    to form the entropy-critical sets L+ / L-
  - omega = W on L+ (variant O1, default); omega = W on L+ and M on L-
    (variant C2 dual-sided); 1 otherwise
  - Closed-loop gate: weights apply only when batch-mean policy entropy
    < stare_target_entropy; otherwise revert to 1 (degrades to GRPO)

Paper defaults are wired in (top_p=0.1, W=1.1, M=0.9, target_entropy=0.3).
STARE aggregation follows the bnpo branch (matches verl's `token-mean`
default for compute_policy_loss_stare).

Co-Authored-By: smellslikeml <9044907+smellslikeml@users.noreply.github.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

@cursor cursor Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cursor Bugbot has reviewed your changes using default effort and found 2 potential issues.

Fix All in Cursor

❌ Bugbot Autofix is OFF. To automatically fix reported issues with cloud agents, enable autofix in the Cursor dashboard.

Want higher recall? High effort reviews run extra passes and find more bugs. A team admin can switch effort levels in the Cursor dashboard.

Reviewed by Cursor Bugbot for commit a0827b0. Configure here.

Comment thread trl/trainer/grpo_trainer.py
selected.reshape(-1)[flat_selected_idx] = True
return selected

l_plus = top_surprisal_mask(positive_candidates)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Top-k pooled across batch

Medium Severity

top_surprisal_mask counts all masked candidates in the entire (B, T) tensor and runs one global topk, so the ceil(count * top_p_ratio) budget is shared across sequences. Longer or higher-surprisal rows can take every slot, leaving shorter completions with no entropy-critical tokens despite the paper’s per-response top-fraction selection.

Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit a0827b0. Configure here.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is intentional — top_surprisal_mask is a verbatim port of the reference implementation (hp-luo/STARE/verl/trainer/ppo/core_algos.py::compute_stare_token_weights), which pools candidates across the (B, T) batch within each advantage-sign partition. The paper's specification matches:

"Split response tokens by trajectory-level advantage sign into T+ = {A_i > 0} and T- = {A_i < 0}. Within each set, rank tokens in descending order of surprisal and select the top-P%..."

— where T+ / T- are batch-wide sets (a trajectory contributes all its tokens to whichever set its A_i sign assigns). This is the design: tokens are entropy-critical relative to the batch surprisal distribution, not their own row. Tokens that are "locally" high-surprisal but globally low-surprisal don't contribute much to entropy change, which is the §4.1 first-order analysis premise — so they shouldn't be reweighted.

The same batch-pool-by-sign approach is also used by the other two known PyTorch ports of this paper:

  • OpenRLHF smellslikeml/OpenRLHF#4kthvalue on flattened batch surprisals.
  • TRL (this PR's earlier PPOTrainer-shaped revision) — torch.quantile on flattened.

Happy to add a # Batch-pooled by sign (matches reference and paper §4.1) comment to make the design choice explicit at the call site if useful.

…l=True`

Address Cursor Bugbot review on PR huggingface#6167: the fused LigerFusedLinearGRPOLoss
path bypasses _compute_loss, so STARE's per-token reweighting was silently
not applied. Add an early ValueError in __init__ matching the existing VESPO
pre-check pattern, skip the (stare, use_liger_kernel=True) combination in
test_train_loss_types, and add a dedicated negative test
test_stare_rejects_liger_kernel.

Co-Authored-By: smellslikeml <9044907+smellslikeml@users.noreply.github.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant