Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions docs/source/paper_index.md
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,40 @@ training_args = GRPOConfig(
```


### STARE: Surprisal-Guided Token-Level Advantage Reweighting for Policy Entropy Stability

**📜 Paper**: https://huggingface.co/papers/2606.19236

STARE addresses policy entropy collapse in PPO/GRPO-style RL on language models. A first-order analysis of per-token entropy dynamics shows that the change in entropy is proportional to the product of a token's advantage and an entropy-sensitivity term that grows with the token's surprisal \\( s_{i,t} = -\log \pi_\theta(o_{i,t} \mid q, o_{i,<t}) \\) — so updates on high-surprisal tokens dominate how fast the policy loses (or keeps) exploration. STARE identifies the *entropy-critical token subset* via a top-fraction-by-surprisal selection within each advantage-sign partition and reweights the per-token policy-gradient (PG) loss of those tokens, gated by a closed-loop target-entropy signal.

The STARE objective wraps the standard GRPO dual-clip surrogate with a token-level reweighting factor \\( \omega_{i,t} \\):

$$
\mathcal{J}_{\text{STARE}}(\theta) = \mathbb{E}_{q, \{o_i\}} \left[ \frac{1}{\sum_i |o_i|} \sum_{i=1}^{G} \sum_{t=1}^{|o_i|} \omega_{i,t} \cdot \min\left( w_{i,t}(\theta) \hat{A}_{i,t},\ \operatorname{clip}(w_{i,t}(\theta), 1-\epsilon, 1+\epsilon)\, \hat{A}_{i,t} \right) \right]
$$

with \\( w_{i,t}(\theta) = \pi_\theta(o_{i,t} \mid q, o_{i,<t}) / \pi_{\text{old}}(o_{i,t} \mid q, o_{i,<t}) \\) the per-token importance ratio. The reweighting factor is

$$
\omega_{i,t} = \begin{cases} W > 1 & \text{if } (i,t) \in L^+ \text{ (positive-advantage critical tokens)} \\ M < 1 & \text{if } (i,t) \in L^- \text{ and variant } = C2 \\ 1 & \text{otherwise} \end{cases}
$$

This reweighting is active only when the closed-loop gate is open, i.e. when \\( \mathbb{1}[\bar{H}_k < H_{\text{tgt}}] = 1 \\), where \\( \bar{H}_k \\) is the batch-mean policy entropy and \\( H_{\text{tgt}} \\) is the target entropy floor. While the gate is closed, all weights revert to 1 and STARE reduces to vanilla GRPO.

```python
from trl import GRPOConfig

# Enable the STARE loss; every `stare_*` value below is the documented default, spelled out for clarity.
training_args = GRPOConfig(
loss_type="stare",
stare_variant="O1", # "O1" (one-sided amplification, default) or "C2" (dual-sided regulation)
stare_top_p_ratio=0.1, # fraction of highest-surprisal tokens per advantage-sign partition treated as critical (paper Section 4.1 default)
stare_reweight_w=1.1, # multiplicative weight W > 1 for L+ (positive-advantage critical tokens)
stare_reweight_m=0.9, # multiplicative weight M < 1 for L- (only active when variant="C2")
stare_target_entropy=0.3, # closed-loop gate floor: reweighting is active only while batch-mean entropy is below it (paper Section 4.3)
)
```


### Rethinking the Trust Region in LLM Reinforcement Learning

**📜 Paper**: https://huggingface.co/papers/2602.04879
Expand Down
146 changes: 145 additions & 1 deletion tests/test_grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,150 @@ def test_compute_entropy_all_masked(self):
torch.testing.assert_close(entropy_mask, expected_mask)


class TestComputeStareTokenWeights(TrlTestCase):
    """Math properties of STARE per-token PG-loss reweighting (paper Sections 4.1-4.3).

    Every test calls `GRPOTrainer.compute_stare_token_weights` directly on tiny hand-built tensors and checks the
    returned `(weights, stats)` pair.
    """

    def _logps(self, logps):
        """Return `logps` as a float32 tensor of per-token log-probabilities.

        Callers pass values that are *already* log-probabilities (all non-positive), so a token's surprisal is the
        negated entry. The values must not go through `torch.log`: the logarithm of a negative number is NaN, which
        would turn every surprisal into NaN and make the top-surprisal selection under test meaningless.
        """
        return torch.tensor(logps, dtype=torch.float32)

    def _low_entropy(self, shape):
        """Return per-token entropies of 0.01, far below the default target_entropy=0.3, so the gate opens."""
        return torch.full(shape, 0.01)

    def _high_entropy(self, shape):
        """Return per-token entropies of 0.5, above target_entropy=0.3, so the gate stays closed."""
        return torch.full(shape, 0.5)

    def test_gate_closed_returns_unit_weights(self):
        """With batch-mean entropy >= target, all weights revert to 1 (vanilla GRPO)."""
        logps = self._logps([[-0.1, -0.2, -3.0, -4.0]])
        advantages = torch.tensor([[1.0, 1.0, 1.0, 1.0]])
        mask = torch.ones_like(advantages)
        weights, stats = GRPOTrainer.compute_stare_token_weights(
            per_token_logps=logps,
            advantages=advantages,
            entropies=self._high_entropy(logps.shape),
            mask=mask,
            target_entropy=0.3,
        )
        torch.testing.assert_close(weights, torch.ones_like(weights))
        assert stats["stare/gate_on"] == 0.0

    def test_variant_o1_reweights_only_positive_critical_tokens(self):
        """Of two positive-advantage tokens, only the higher-surprisal one lands in L+ and receives reweight_w."""
        logps = self._logps([[-0.1, -3.0]])  # surprisal: [0.1, 3.0] → token 1 is critical
        advantages = torch.tensor([[1.0, 1.0]])
        mask = torch.ones_like(advantages)
        weights, stats = GRPOTrainer.compute_stare_token_weights(
            per_token_logps=logps,
            advantages=advantages,
            entropies=self._low_entropy(logps.shape),
            mask=mask,
            variant="O1",
            top_p_ratio=0.5,  # 50% → ceil(2*0.5)=1 critical token per partition
            reweight_w=1.5,
            reweight_m=0.9,
        )
        torch.testing.assert_close(weights, torch.tensor([[1.0, 1.5]]))
        assert stats["stare/gate_on"] == 1.0

    def test_variant_o1_ignores_negative_advantage_tokens(self):
        """O1 leaves negative-advantage tokens at weight 1 regardless of their surprisal."""
        logps = self._logps([[-3.0, -3.0]])
        advantages = torch.tensor([[1.0, -1.0]])  # one positive, one negative
        mask = torch.ones_like(advantages)
        weights, _ = GRPOTrainer.compute_stare_token_weights(
            per_token_logps=logps,
            advantages=advantages,
            entropies=self._low_entropy(logps.shape),
            mask=mask,
            variant="O1",
            top_p_ratio=1.0,
            reweight_w=1.5,
            reweight_m=0.5,
        )
        # Positive-advantage token reweighted by W; negative-advantage token untouched.
        torch.testing.assert_close(weights, torch.tensor([[1.5, 1.0]]))

    def test_variant_c2_reweights_both_partitions(self):
        """C2 additionally damps negative-advantage critical tokens by reweight_m."""
        logps = self._logps([[-3.0, -3.0]])
        advantages = torch.tensor([[1.0, -1.0]])
        mask = torch.ones_like(advantages)
        weights, _ = GRPOTrainer.compute_stare_token_weights(
            per_token_logps=logps,
            advantages=advantages,
            entropies=self._low_entropy(logps.shape),
            mask=mask,
            variant="C2",
            top_p_ratio=1.0,
            reweight_w=1.5,
            reweight_m=0.5,
        )
        torch.testing.assert_close(weights, torch.tensor([[1.5, 0.5]]))

    def test_mask_excludes_padded_tokens(self):
        """Padded positions are never selected as critical, even when they carry the highest surprisal."""
        logps = self._logps([[-0.1, -3.0, -5.0]])  # token 2 has highest surprisal but is padded
        advantages = torch.tensor([[1.0, 1.0, 1.0]])
        mask = torch.tensor([[1.0, 1.0, 0.0]])  # last token padded
        weights, stats = GRPOTrainer.compute_stare_token_weights(
            per_token_logps=logps,
            advantages=advantages,
            entropies=self._low_entropy(logps.shape) * mask,
            mask=mask,
            variant="O1",
            top_p_ratio=0.5,  # ceil(2*0.5)=1 critical token out of 2 valid candidates
            reweight_w=1.5,
        )
        # Token 1 (surprisal 3.0) is the highest-surprisal *valid* candidate.
        torch.testing.assert_close(weights, torch.tensor([[1.0, 1.5, 1.0]]))
        # Padded token contributes 0 to the reweight-ratio denominator.
        assert stats["stare/positive_reweight_ratio"] == 0.5

    def test_top_p_ratio_selects_correct_count(self):
        """With 4 positive-advantage tokens and top_p_ratio=0.5, exactly ceil(4*0.5)=2 tokens are reweighted."""
        logps = self._logps([[-0.1, -0.2, -3.0, -4.0]])  # tokens 2, 3 have highest surprisal
        advantages = torch.tensor([[1.0, 1.0, 1.0, 1.0]])
        mask = torch.ones_like(advantages)
        weights, _ = GRPOTrainer.compute_stare_token_weights(
            per_token_logps=logps,
            advantages=advantages,
            entropies=self._low_entropy(logps.shape),
            mask=mask,
            variant="O1",
            top_p_ratio=0.5,
            reweight_w=1.5,
        )
        torch.testing.assert_close(weights, torch.tensor([[1.0, 1.0, 1.5, 1.5]]))

    def test_rejects_invalid_variant(self):
        """An unknown variant name raises ValueError."""
        logps = self._logps([[-0.5]])
        advantages = torch.tensor([[1.0]])
        mask = torch.ones_like(advantages)
        with pytest.raises(ValueError, match="Unsupported STARE variant"):
            GRPOTrainer.compute_stare_token_weights(
                per_token_logps=logps,
                advantages=advantages,
                entropies=self._low_entropy(logps.shape),
                mask=mask,
                variant="X1",
            )

    def test_rejects_invalid_top_p_ratio(self):
        """A top_p_ratio of 1.5 raises ValueError mentioning `stare_top_p_ratio`."""
        logps = self._logps([[-0.5]])
        advantages = torch.tensor([[1.0]])
        mask = torch.ones_like(advantages)
        with pytest.raises(ValueError, match="stare_top_p_ratio"):
            GRPOTrainer.compute_stare_token_weights(
                per_token_logps=logps,
                advantages=advantages,
                entropies=self._low_entropy(logps.shape),
                mask=mask,
                top_p_ratio=1.5,
            )


class TestGRPORolloutDispatch:
def _make_trainer(self):
trainer = object.__new__(GRPOTrainer)
Expand Down Expand Up @@ -393,7 +537,7 @@ def reward_func(completions, **kwargs):
assert type(trainer.model).__name__ == "RemoteForCausalLM"

@pytest.mark.parametrize("use_liger_kernel", [False, pytest.param(True, marks=require_liger_kernel)])
@pytest.mark.parametrize("loss_type", ["bnpo", "dr_grpo", "dapo", "cispo", "sapo", "luspo", "vespo"])
@pytest.mark.parametrize("loss_type", ["bnpo", "dr_grpo", "dapo", "cispo", "sapo", "luspo", "vespo", "stare"])
def test_train_loss_types(self, loss_type, use_liger_kernel):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

Expand Down
63 changes: 63 additions & 0 deletions trl/trainer/grpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,23 @@ class GRPOConfig(_BaseConfig):
lambda parameter for negative advantages, it is the exponential decay factor in the VESPO loss. Controls
how aggressively we down-weight samples with high importance weights (when the importance sampling ratio >
1).
stare_variant (`str`, *optional*, defaults to `"O1"`):
STARE token-reweighting variant. `"O1"` (default, paper Section 4.2) amplifies positive-advantage
entropy-critical tokens only. `"C2"` additionally damps negative-advantage entropy-critical tokens.
Introduced in the [STARE paper](https://huggingface.co/papers/2606.19236).
stare_top_p_ratio (`float`, *optional*, defaults to `0.1`):
Fraction of highest-surprisal tokens (within each advantage-sign partition) treated as entropy-critical
and reweighted by STARE. Paper Section 4.1 default is `0.1`.
stare_reweight_w (`float`, *optional*, defaults to `1.1`):
Multiplicative weight `W > 1` applied to the per-token PG loss of positive-advantage entropy-critical
tokens (the set L+). Paper Section 4.2 default is `1.1`.
stare_reweight_m (`float`, *optional*, defaults to `0.9`):
Multiplicative weight `M < 1` applied to the per-token PG loss of negative-advantage entropy-critical
tokens (the set L-). Only active when `stare_variant="C2"`. Paper Section 4.2 default is `0.9`.
stare_target_entropy (`float`, *optional*, defaults to `0.3`):
Target policy entropy floor for STARE's closed-loop gate (paper Section 4.3). The gate opens
(reweighting is active) when the batch-mean policy entropy falls below this value; otherwise weights
revert to 1 and STARE degrades to vanilla GRPO. Paper default is `0.3`.
importance_sampling_level (`str`, *optional*, defaults to `"token"`):
Controls whether importance sampling ratios are computed at the `"token"` or `"sequence"` level. `"token"`
keeps the raw per-token log-probability ratios (one weight per token). `"sequence"` averages the
Expand Down Expand Up @@ -272,6 +289,10 @@ class GRPOConfig(_BaseConfig):
- `"vespo"`: Variational Sequence-Level Soft Policy Optimization. Replaces hard clipping with a smooth,
asymmetric Gamma weighting function applied directly to sequence-level importance weights. Introduced in
the [VESPO paper](https://huggingface.co/papers/2602.10693).
- `"stare"`: Surprisal-Guided Token-Level Advantage Reweighting for Policy Entropy Stability. Wraps the
GRPO dual-clip surrogate and reweights the per-token PG loss of entropy-critical tokens (the
highest-surprisal tokens within each advantage-sign partition), gated by a closed-loop target-entropy
signal. Introduced in the [STARE paper](https://huggingface.co/papers/2606.19236).
mask_truncated_completions (`bool`, *optional*, defaults to `False`):
When enabled, truncated completions are excluded from the loss calculation, preventing them from being
incorrectly penalized and introducing noise during training. According to the
Expand Down Expand Up @@ -716,6 +737,44 @@ class GRPOConfig(_BaseConfig):
"sampling ratio > 1)."
},
)
# --- STARE hyper-parameters (documented as applying to the "stare" loss type) ---
# NOTE(review): no validation of these values is visible in this hunk; the trainer tests expect
# `GRPOTrainer.compute_stare_token_weights` to raise ValueError for an unsupported variant or a bad top-p ratio.

# Which advantage-sign partitions are reweighted: "O1" (positive only) or "C2" (positive and negative).
stare_variant: str = field(
default="O1",
metadata={
"help": "STARE token-reweighting variant. `'O1'` (default, paper Section 4.2) amplifies positive-advantage "
"entropy-critical tokens only. `'C2'` additionally damps negative-advantage entropy-critical tokens. "
"Introduced in the [STARE paper](https://huggingface.co/papers/2606.19236)."
},
)
# Fraction of the highest-surprisal tokens, taken per advantage-sign partition, that count as entropy-critical.
stare_top_p_ratio: float = field(
default=0.1,
metadata={
"help": "Fraction of highest-surprisal tokens (within each advantage-sign partition) treated as "
"entropy-critical and reweighted by STARE. Paper Section 4.1 default is `0.1`."
},
)
# Amplification factor W for positive-advantage critical tokens (set L+).
stare_reweight_w: float = field(
default=1.1,
metadata={
"help": "Multiplicative weight `W > 1` applied to the per-token PG loss of positive-advantage "
"entropy-critical tokens (the set L+). Paper Section 4.2 default is `1.1`."
},
)
# Damping factor M for negative-advantage critical tokens (set L-); documented as used only by variant "C2".
stare_reweight_m: float = field(
default=0.9,
metadata={
"help": "Multiplicative weight `M < 1` applied to the per-token PG loss of negative-advantage "
"entropy-critical tokens (the set L-). Only active when `stare_variant='C2'`. Paper Section 4.2 default "
"is `0.9`."
},
)
# Entropy floor for the closed-loop gate: reweighting is documented as active only while batch-mean entropy is below it.
stare_target_entropy: float = field(
default=0.3,
metadata={
"help": "Target policy entropy floor for STARE's closed-loop gate (paper Section 4.3). The gate opens "
"(reweighting is active) when the batch-mean policy entropy falls below this value; otherwise weights "
"revert to 1 and STARE degrades to vanilla GRPO. Paper default is `0.3`."
},
)
importance_sampling_level: str = field(
default="token",
metadata={
Expand Down Expand Up @@ -790,6 +849,10 @@ class GRPOConfig(_BaseConfig):
"'vespo': Variational Sequence-Level Soft Policy Optimization. Replaces hard clipping with a smooth, "
"asymmetric Gamma weighting function applied directly to sequence-level importance weights. Introduced in "
"the [VESPO paper](https://huggingface.co/papers/2602.10693)."
"'stare': Surprisal-Guided Token-Level Advantage Reweighting for Policy Entropy Stability. Wraps the "
"GRPO dual-clip surrogate and reweights the per-token PG loss of entropy-critical tokens (the "
"highest-surprisal tokens within each advantage-sign partition), gated by a closed-loop target-entropy "
"signal. Introduced in the [STARE paper](https://huggingface.co/papers/2606.19236)."
},
)
mask_truncated_completions: bool = field(
Expand Down
Loading