Skip to content
Open
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
ac50a11
Add fields to GRPOConfig
albertvillanova Jun 22, 2026
dcaaf67
Add init fields to GRPOTrainer
albertvillanova Jun 22, 2026
0f6306e
Update _compute_loss
albertvillanova Jun 22, 2026
9b1cc65
Add checkpoint persistence
albertvillanova Jun 22, 2026
e944713
Update GRPO docs
albertvillanova Jun 22, 2026
f47d5a5
Add tests
albertvillanova Jun 22, 2026
2484e70
Address issues from review
albertvillanova Jun 22, 2026
4507747
Fix wrong entropy for adaptive control
albertvillanova Jun 22, 2026
9b70a4a
Fix Liger skips adaptive entropy guard
albertvillanova Jun 22, 2026
9d79e4a
Fix inconsistent inequality
albertvillanova Jun 22, 2026
46c8a64
Fix mean reduction with sum-count-divide
albertvillanova Jun 22, 2026
3f7a669
Set _last_world_entropy at init
albertvillanova Jun 22, 2026
a05c979
Cache world_entropy at sync point and use that cached value for apply…
albertvillanova Jun 22, 2026
fe03dd1
Persist also _last_world_entropy
albertvillanova Jun 22, 2026
f099349
Add paper_index entry
albertvillanova Jun 22, 2026
5288cd5
Capture the pure policy loss before normalization
albertvillanova Jun 24, 2026
03f4208
Fix luspo loss
albertvillanova Jun 24, 2026
dbc0c75
Gate policy_loss logging and align style
albertvillanova Jun 24, 2026
391da7a
Merge remote-tracking branch 'upstream/main' into worktree-fix-3320
albertvillanova Jun 24, 2026
506fbf9
Fix entropy state written to wrong path
albertvillanova Jun 24, 2026
8a6b53d
Fix is_world_process_zero() vs args.should_save guard mismatch
albertvillanova Jun 24, 2026
474b30c
Update docs: policy_loss only logged inside entropy block
albertvillanova Jun 24, 2026
a0b9ec6
Log entropy_coef only when sync_gradients=True
albertvillanova Jun 24, 2026
608b1e0
Add guard for entropy-loss dispatch matching policy-loss dispatch
albertvillanova Jun 24, 2026
81841ad
Remove entropy_loss
albertvillanova Jun 24, 2026
bee5126
Gate on train mode to avoid entropy state update during eval
albertvillanova Jun 24, 2026
5c442a0
Merge remote-tracking branch 'upstream/main' into worktree-fix-3320
albertvillanova Jun 24, 2026
2f34d15
Fix entropy bonus ignores quantile mask
albertvillanova Jun 24, 2026
806078d
Use effective_mask for the world_entropy all-reduce too
albertvillanova Jun 24, 2026
2845ef4
Update docs
albertvillanova Jun 24, 2026
2ed11c0
Use unified formula with mean per-token entropy of active tokens
albertvillanova Jun 24, 2026
7f0562b
Merge remote-tracking branch 'upstream/main' into worktree-fix-3320
albertvillanova Jun 25, 2026
76255d3
Make three-branch entropy-loss split
albertvillanova Jun 25, 2026
fc76d4b
Compute bonus from frozen state, update per optimizer step
albertvillanova Jun 25, 2026
bed5188
Fix "nearly always triggers" docs
albertvillanova Jun 25, 2026
6e8f498
Add scale test and grad-accumulation adaptive test
albertvillanova Jun 25, 2026
607d911
Fix dr_grpo entropy scale mismatch
albertvillanova Jun 25, 2026
0cfad37
Accumulate to mean per-token entropy, independent of how each loss ty…
albertvillanova Jun 25, 2026
8e05132
Update tests
albertvillanova Jun 25, 2026
f15e04a
Merge remote-tracking branch 'upstream/main' into worktree-fix-3320
albertvillanova Jun 26, 2026
bccd8eb
Add clarifying sentence
albertvillanova Jun 26, 2026
0f3e145
Merge branch 'main' into worktree-fix-3320
qgallouedec Jun 29, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions docs/source/grpo_trainer.md
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,10 @@ While training and evaluating, we record the following reward metrics:
- `reward`: The overall average reward after summing rewards across functions (weighted by `reward_weights`).
- `reward_std`: The standard deviation of summed rewards across functions (weighted by `reward_weights`), computed over the full batch.
- `frac_reward_zero_std`: The fraction of samples in the generation batch with a reward std of zero, implying there is little diversity for that prompt (all answers are correct or incorrect).
- `policy_loss`: The policy gradient loss value (before any entropy bonus).
- `entropy`: Average entropy of token predictions across generated completions. (If `mask_truncated_completions=True`, masked sequences tokens are excluded.)
- `entropy_loss`: Mean per-token entropy (nats) used as the regularization signal. Logged when `entropy_coef` is nonzero or `use_adaptive_entropy=True`.
- `entropy_coef`: The current entropy coefficient. Logged when `entropy_coef` is nonzero or `use_adaptive_entropy=True`. Updated once per optimizer step when `use_adaptive_entropy=True`.
- `kl`: The average KL divergence between the model and the reference model, calculated over generated completions. Logged only if `beta` is nonzero.
- `clip_ratio/region_mean`: The ratio of token (or sequence, if `importance_sampling_level="sequence"`) probabilities where the GRPO objective is clipped to stay within the trust region: \\( \text{clip}\left( r_{i,t}(\theta), 1 - \epsilon_\mathrm{low}, 1 + \epsilon_\mathrm{high} \right)\,, \quad r_{i,t}(\theta) = \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})} \\). A higher value means more tokens are clipped, which constrains how much the policy $\pi_\theta$ can change.
- `clip_ratio/low_mean`: The average ratio of token (or sequence, if `importance_sampling_level="sequence"`) probabilities that were clipped on the lower bound of the trust region: \\(r_{i,t}(\theta) < 1 - \epsilon_\mathrm{low}\\).
Expand Down Expand Up @@ -641,6 +644,46 @@ and the reward will be computed as the sum of the rewards from each function, or

Note that [`GRPOTrainer`] supports multiple reward functions of different types. See the parameters documentation for more details.

### Entropy regularization

To encourage exploration and prevent the policy from collapsing to near-deterministic outputs, you can add an entropy bonus to the training objective. The entropy regularization augments the GRPO loss as follows:

$$
\mathcal{L}(\theta) = \mathcal{L}_{\text{GRPO}}(\theta) - \alpha \cdot \mathcal{H}(\pi_\theta),
$$

where \\(\mathcal{H}(\pi_\theta)\\) is the mean per-token entropy of the policy and \\(\alpha\\) is the entropy coefficient.

**Static entropy** — a fixed coefficient throughout training:

```python
from trl import GRPOConfig, GRPOTrainer

training_args = GRPOConfig(entropy_coef=0.05, ...)
```

**Adaptive entropy** — the coefficient is updated each optimizer step based on a target entropy, as introduced in [Skywork-OR1](https://huggingface.co/papers/2505.22312). When the current entropy falls at or below `entropy_target`, the coefficient is incremented by `entropy_coef_delta`; otherwise it is decremented. The coefficient is only applied (i.e. non-zero) while entropy is at or below the target:

```python
training_args = GRPOConfig(
entropy_coef=0.01, # initial coefficient
use_adaptive_entropy=True,
entropy_target=5.0, # target mean per-token entropy (nats); tune for your model
entropy_coef_delta=0.005, # step size per optimizer step
entropy_coef_min=0.0,
entropy_coef_max=1.0,
...
)
```

<Tip>

Typical language models have per-token entropies of 2–10 nats. The default `entropy_target=0.2` nearly always triggers regularization; set it to a value meaningful for your model (e.g. the entropy you observe early in training, logged as the `entropy` metric).

</Tip>

When `use_adaptive_entropy=True`, the current entropy coefficient `entropy_coef` is saved alongside each checkpoint and restored on resume, so training is fully resumable.

### Rapid Experimentation for GRPO

RapidFire AI is an open-source experimentation engine that sits on top of TRL and lets you launch multiple GRPO configurations at once, even on a single GPU. Instead of trying configurations sequentially, RapidFire lets you **see all their learning curves earlier, stop underperforming runs, and clone promising ones with new settings in flight** without restarting. For more information, see [RapidFire AI Integration](rapidfire_integration).
Expand Down
68 changes: 68 additions & 0 deletions tests/test_grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1474,6 +1474,74 @@ def test_train_with_cast_lm_head_to_fp32(self, model_name):
new_param = trainer.model.get_parameter(n)
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

def test_train_with_static_entropy(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
training_args = GRPOConfig(
output_dir=self.tmp_dir,
learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=8, # reduce the completion length to reduce memory usage
report_to="none",
entropy_coef=0.1,
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=training_args,
train_dataset=dataset,
)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()

assert trainer.state.log_history[-1]["train_loss"] is not None
assert trainer.state.log_history[-1]["policy_loss"] is not None
assert trainer.state.log_history[-1]["entropy_loss"] is not None
assert trainer.state.log_history[-1]["entropy_coef"] is not None

# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

def test_train_with_adaptive_entropy(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
training_args = GRPOConfig(
output_dir=self.tmp_dir,
learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=8, # reduce the completion length to reduce memory usage
report_to="none",
entropy_coef=0.01,
use_adaptive_entropy=True,
entropy_target=15.0, # above any realistic entropy → coef is always incremented
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=training_args,
train_dataset=dataset,
)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()

assert trainer.state.log_history[-1]["train_loss"] is not None
assert trainer.state.log_history[-1]["policy_loss"] is not None
assert trainer.state.log_history[-1]["entropy_loss"] is not None
assert trainer.state.log_history[-1]["entropy_coef"] is not None
# Coefficient should have increased since entropy < target throughout training
assert trainer.entropy_coef > 0.01

# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

def test_train_with_entropy_filter(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
training_args = GRPOConfig(
Expand Down
58 changes: 58 additions & 0 deletions trl/trainer/grpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,28 @@ class GRPOConfig(_BaseConfig):
position, improving results. Range: `[0.0-1.0]`. A value of `0.0` masks all but the highest entropy token;
`1.0` keeps all tokens. The paper recommends a value of `0.2`. If used with
`mask_truncated_completions=True`, only tokens from non-truncated completions are considered.
entropy_coef (`float`, *optional*, defaults to `0.0`):
Coefficient of the entropy regularization term in the loss. A positive value adds an entropy bonus that
encourages exploration by keeping the policy from collapsing to near-deterministic outputs. When
`use_adaptive_entropy=True`, this serves as the initial coefficient and is updated each optimizer step.
Has no effect when set to `0.0` (default).
use_adaptive_entropy (`bool`, *optional*, defaults to `False`):
Whether to use adaptive entropy control, introduced in
[Skywork-OR1](https://huggingface.co/papers/2505.22312). When enabled, the entropy coefficient
`entropy_coef` is updated each optimizer step: incremented by `entropy_coef_delta` when the current
entropy is below `entropy_target`, and decremented otherwise. The coefficient is only applied when
entropy is at or below `entropy_target`.
Comment thread
cursor[bot] marked this conversation as resolved.
entropy_coef_min (`float`, *optional*, defaults to `0.0`):
Lower bound for the entropy coefficient when using adaptive entropy control.
entropy_coef_max (`float`, *optional*, defaults to `1.0`):
Upper bound for the entropy coefficient when using adaptive entropy control.
entropy_coef_delta (`float`, *optional*, defaults to `0.005`):
Step size for adjusting the entropy coefficient at each optimizer step during adaptive entropy control.
entropy_target (`float`, *optional*, defaults to `0.2`):
Target mean per-token entropy (in nats) used by adaptive entropy control. The coefficient is only
applied when the current entropy falls at or below this value. Typical language models have per-token
entropies in the range 2–10 nats; the default of `0.2` nearly always triggers regularization, so users
should tune this to a value appropriate for their model and task.
max_tool_calling_iterations (`int`, *optional*):
Maximum number of tool-calling turns when training an agent. If `None`, there is no limit and generation
stops when the model generates a response turn with no tool calls or when the total response length reaches
Expand Down Expand Up @@ -832,6 +854,42 @@ class GRPOConfig(_BaseConfig):
"non-truncated completions are considered."
},
)
entropy_coef: float = field(
default=0.0,
metadata={
"help": "Coefficient of the entropy regularization term in the loss. A positive value adds an entropy "
"bonus that encourages exploration. When `use_adaptive_entropy=True`, this serves as the initial "
"coefficient and is updated each optimizer step. Has no effect when set to `0.0` (default)."
},
)
use_adaptive_entropy: bool = field(
default=False,
metadata={
"help": "Whether to use adaptive entropy control, introduced in Skywork-OR1 "
"(https://huggingface.co/papers/2505.22312). When enabled, `entropy_coef` is incremented by "
"`entropy_coef_delta` when entropy is below `entropy_target`, and decremented otherwise."
},
)
entropy_coef_min: float = field(
default=0.0,
metadata={"help": "Lower bound for the entropy coefficient when using adaptive entropy control."},
)
entropy_coef_max: float = field(
default=1.0,
metadata={"help": "Upper bound for the entropy coefficient when using adaptive entropy control."},
)
entropy_coef_delta: float = field(
default=0.005,
metadata={"help": "Step size for adjusting the entropy coefficient during adaptive entropy control."},
)
entropy_target: float = field(
default=0.2,
metadata={
"help": "Target mean per-token entropy (nats) for adaptive entropy control. The coefficient is only "
"applied when current entropy is at or below this value. Typical language models have per-token "
"entropies of 2–10 nats; the default of 0.2 nearly always triggers regularization, so tune this."
},
)
max_tool_calling_iterations: int | None = field(
default=None,
metadata={
Expand Down
71 changes: 71 additions & 0 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import copy
import importlib.resources as pkg_resources
import inspect
import json
import math
import os
import sys
Expand Down Expand Up @@ -667,6 +668,13 @@ def __init__(
f"Unknown importance sampling level: {self.importance_sampling_level}. "
"Possible values are 'token' and 'sequence'."
)
self.entropy_coef = args.entropy_coef
self.use_adaptive_entropy = args.use_adaptive_entropy
# Cached entropy from the last optimizer step; inf so the first accumulation window
# applies no bonus until a real measurement arrives (conservative default).
self._last_world_entropy = float("inf")
Comment thread
cursor[bot] marked this conversation as resolved.
if self.use_liger_kernel and (self.entropy_coef != 0.0 or self.use_adaptive_entropy):
raise NotImplementedError("Entropy bonus is not supported with Liger kernel.")
Comment thread
cursor[bot] marked this conversation as resolved.

# Datasets
self.shuffle_dataset = args.shuffle_dataset
Expand Down Expand Up @@ -2736,6 +2744,51 @@ def _compute_loss(self, model, inputs):
else:
raise ValueError(f"Unknown loss type: {self.loss_type}")

# Capture the pure policy loss for logging before entropy/aux modify it
policy_loss = loss.detach()

# Entropy bonus: add entropy regularization to encourage exploration.
# Gate: run whenever a non-zero static coef is set OR adaptive mode is enabled. Adaptive must always run even
# when self.entropy_coef has been decremented to entropy_coef_min (default 0) so it can recover once entropy
# drops below entropy_target again.
if self.entropy_coef != 0.0 or self.use_adaptive_entropy:
if self.loss_type in ["grpo", "sapo", "luspo"]:
entropy_loss = ((entropies * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).mean() / normalizer

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The entropy block lumps luspo with grpo/sapo, but luspo's actual loss is (per_token_loss * mask.sum(1,keepdim=True)).mean(). So for luspo the entropy bonus lives on a different scale than the policy term and entropy_coef means something different there. Either give luspo its own branch or note it. To be confirmed, but probably something like:

entropy_loss = (entropies * mask).sum(-1).mean() / normalizer

elif self.loss_type == "bnpo":
entropy_loss = (entropies * mask).sum() / mask.sum().clamp(min=1.0) / normalizer
elif self.loss_type == "dr_grpo":
entropy_loss = (entropies * mask).sum() / (entropies.size(0) * self.max_completion_length) / normalizer
elif self.loss_type in ["cispo", "dapo", "vespo"]:
entropy_loss = (entropies * mask).sum() / normalizer

# True global mean per-token entropy (nats): reduce sum and token count jointly so
# that ranks with fewer tokens don't get equal weight (averaging per-rank means would
# be biased when completion lengths differ across ranks).
entropy_stats = self.accelerator.reduce(
torch.stack([(entropies * mask).sum(), mask.sum()]).detach(), reduction="sum"
)
world_entropy = (entropy_stats[0] / entropy_stats[1].clamp(min=1.0)).item()
if self.use_adaptive_entropy:
# Update coefficient and cache entropy once per optimizer step, not per micro-batch.
# apply_coef uses the cached value so all micro-batches within one accumulation
# window apply the same bonus (using per-micro-batch world_entropy would cause
# the bonus to toggle on/off unpredictably across accumulation steps).
if self.accelerator.sync_gradients:
if world_entropy <= self.args.entropy_target:
self.entropy_coef = min(
self.entropy_coef + self.args.entropy_coef_delta, self.args.entropy_coef_max
)
else:
self.entropy_coef = max(
self.entropy_coef - self.args.entropy_coef_delta, self.args.entropy_coef_min
)
self._last_world_entropy = world_entropy
apply_coef = self.entropy_coef if self._last_world_entropy <= self.args.entropy_target else 0.0
Comment thread
cursor[bot] marked this conversation as resolved.
Comment thread
cursor[bot] marked this conversation as resolved.
else:
apply_coef = self.entropy_coef

loss = loss - apply_coef * entropy_loss
Comment thread
cursor[bot] marked this conversation as resolved.
Comment thread
cursor[bot] marked this conversation as resolved.

# The policy loss above is scaled for gradient accumulation (HF auto-scaling is off here), so scale aux too
if self.aux_loss_enabled:
normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0
Expand All @@ -2751,6 +2804,11 @@ def masked_batch_mean(x):
else:
return (x * mask).sum() / completion_token_count

self._metrics[mode]["policy_loss"].append(self.accelerator.reduce(policy_loss, reduction="mean").item())

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

policy_loss is captured after loss = loss / normalizer, where normalizer = current_gradient_accumulation_steps, so the logged policy_loss is the per-micro-batch contribution, not the step loss; it'll read ~accum× too small. I think we should capture before dividing.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

another remark, it can be misleading to have policy_loss logged when it's not used in the loss. Maybe we should gate its logging.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and just for matching home-style:

self._metrics[mode]["policy_loss"].append(self.accelerator.gather(policy_loss).nanmean().item())

if self.entropy_coef != 0.0 or self.use_adaptive_entropy:
self._metrics[mode]["entropy_loss"].append(world_entropy)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We would have two near-duplicate entropy metrics. We could drop it, and just log entropy. The existing entropy already logs mean entropy; entropy_loss is just a slightly different (global vs gathered-local-mean) computation of the same quantity.

self._metrics[mode]["entropy_coef"].append(self.entropy_coef)

if self.beta != 0.0:
mean_kl = masked_batch_mean(per_token_kl)
self._metrics[mode]["kl"].append(self.accelerator.gather(mean_kl).nanmean().item())
Expand Down Expand Up @@ -2885,3 +2943,16 @@ def _save_checkpoint(self, model, trial):
model_name = self.args.hub_model_id.split("/")[-1]
self.create_model_card(model_name=model_name)
super()._save_checkpoint(model, trial)
if self.use_adaptive_entropy and self.is_world_process_zero():
checkpoint_folder = f"checkpoint-{self.state.global_step}"

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

prefer PREFIX_CHECKPOINT_DIR from transformers.trainer_utils so this stays correct if HF ever renames it.

output_dir = os.path.join(self.args.output_dir, checkpoint_folder)
with open(os.path.join(output_dir, "entropy_ctrl_state.json"), "w") as f:
json.dump({"entropy_coef": self.entropy_coef}, f)

def _load_optimizer_and_scheduler(self, checkpoint):
super()._load_optimizer_and_scheduler(checkpoint)
if self.use_adaptive_entropy and checkpoint is not None:
path = os.path.join(checkpoint, "entropy_ctrl_state.json")
if os.path.exists(path):
with open(path) as f:
self.entropy_coef = json.load(f)["entropy_coef"]
Loading