-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Add entropy regularization to GRPO #6140
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 13 commits
ac50a11
dcaaf67
0f6306e
9b1cc65
e944713
f47d5a5
2484e70
4507747
9b70a4a
9d79e4a
46c8a64
3f7a669
a05c979
fe03dd1
f099349
5288cd5
03f4208
dbc0c75
391da7a
506fbf9
8a6b53d
474b30c
a0b9ec6
608b1e0
81841ad
bee5126
5c442a0
2f34d15
806078d
2845ef4
2ed11c0
7f0562b
76255d3
fc76d4b
bed5188
6e8f498
607d911
0cfad37
8e05132
f15e04a
bccd8eb
0f3e145
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,7 @@ | |
| import copy | ||
| import importlib.resources as pkg_resources | ||
| import inspect | ||
| import json | ||
| import math | ||
| import os | ||
| import sys | ||
|
|
@@ -667,6 +668,13 @@ def __init__( | |
| f"Unknown importance sampling level: {self.importance_sampling_level}. " | ||
| "Possible values are 'token' and 'sequence'." | ||
| ) | ||
| self.entropy_coef = args.entropy_coef | ||
| self.use_adaptive_entropy = args.use_adaptive_entropy | ||
| # Cached entropy from the last optimizer step; inf so the first accumulation window | ||
| # applies no bonus until a real measurement arrives (conservative default). | ||
| self._last_world_entropy = float("inf") | ||
|
cursor[bot] marked this conversation as resolved.
|
||
| if self.use_liger_kernel and (self.entropy_coef != 0.0 or self.use_adaptive_entropy): | ||
| raise NotImplementedError("Entropy bonus is not supported with Liger kernel.") | ||
|
cursor[bot] marked this conversation as resolved.
|
||
|
|
||
| # Datasets | ||
| self.shuffle_dataset = args.shuffle_dataset | ||
|
|
@@ -2736,6 +2744,51 @@ def _compute_loss(self, model, inputs): | |
| else: | ||
| raise ValueError(f"Unknown loss type: {self.loss_type}") | ||
|
|
||
| # Capture the pure policy loss for logging before entropy/aux modify it | ||
| policy_loss = loss.detach() | ||
|
|
||
| # Entropy bonus: add entropy regularization to encourage exploration. | ||
| # Gate: run whenever a non-zero static coef is set OR adaptive mode is enabled. Adaptive must always run even | ||
| # when self.entropy_coef has been decremented to entropy_coef_min (default 0) so it can recover once entropy | ||
| # drops below entropy_target again. | ||
| if self.entropy_coef != 0.0 or self.use_adaptive_entropy: | ||
| if self.loss_type in ["grpo", "sapo", "luspo"]: | ||
| entropy_loss = ((entropies * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).mean() / normalizer | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The entropy block lumps luspo with grpo/sapo, but luspo's actual loss is entropy_loss = (entropies * mask).sum(-1).mean() / normalizer |
||
| elif self.loss_type == "bnpo": | ||
| entropy_loss = (entropies * mask).sum() / mask.sum().clamp(min=1.0) / normalizer | ||
| elif self.loss_type == "dr_grpo": | ||
| entropy_loss = (entropies * mask).sum() / (entropies.size(0) * self.max_completion_length) / normalizer | ||
| elif self.loss_type in ["cispo", "dapo", "vespo"]: | ||
| entropy_loss = (entropies * mask).sum() / normalizer | ||
|
|
||
| # True global mean per-token entropy (nats): reduce sum and token count jointly so | ||
| # that ranks with fewer tokens don't get equal weight (averaging per-rank means would | ||
| # be biased when completion lengths differ across ranks). | ||
| entropy_stats = self.accelerator.reduce( | ||
| torch.stack([(entropies * mask).sum(), mask.sum()]).detach(), reduction="sum" | ||
| ) | ||
| world_entropy = (entropy_stats[0] / entropy_stats[1].clamp(min=1.0)).item() | ||
| if self.use_adaptive_entropy: | ||
| # Update coefficient and cache entropy once per optimizer step, not per micro-batch. | ||
| # apply_coef uses the cached value so all micro-batches within one accumulation | ||
| # window apply the same bonus (using per-micro-batch world_entropy would cause | ||
| # the bonus to toggle on/off unpredictably across accumulation steps). | ||
| if self.accelerator.sync_gradients: | ||
| if world_entropy <= self.args.entropy_target: | ||
| self.entropy_coef = min( | ||
| self.entropy_coef + self.args.entropy_coef_delta, self.args.entropy_coef_max | ||
| ) | ||
| else: | ||
| self.entropy_coef = max( | ||
| self.entropy_coef - self.args.entropy_coef_delta, self.args.entropy_coef_min | ||
| ) | ||
| self._last_world_entropy = world_entropy | ||
| apply_coef = self.entropy_coef if self._last_world_entropy <= self.args.entropy_target else 0.0 | ||
|
cursor[bot] marked this conversation as resolved.
cursor[bot] marked this conversation as resolved.
|
||
| else: | ||
| apply_coef = self.entropy_coef | ||
|
|
||
| loss = loss - apply_coef * entropy_loss | ||
|
cursor[bot] marked this conversation as resolved.
cursor[bot] marked this conversation as resolved.
|
||
|
|
||
| # The policy loss above is scaled for gradient accumulation (HF auto-scaling is off here), so scale aux too | ||
| if self.aux_loss_enabled: | ||
| normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 | ||
|
|
@@ -2751,6 +2804,11 @@ def masked_batch_mean(x): | |
| else: | ||
| return (x * mask).sum() / completion_token_count | ||
|
|
||
| self._metrics[mode]["policy_loss"].append(self.accelerator.reduce(policy_loss, reduction="mean").item()) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. policy_loss is captured after
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. another remark, it can be misleading to have
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. and just for matching home-style: self._metrics[mode]["policy_loss"].append(self.accelerator.gather(policy_loss).nanmean().item()) |
||
| if self.entropy_coef != 0.0 or self.use_adaptive_entropy: | ||
| self._metrics[mode]["entropy_loss"].append(world_entropy) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We would have two near-duplicate entropy metrics. We could drop it, and just log |
||
| self._metrics[mode]["entropy_coef"].append(self.entropy_coef) | ||
|
|
||
| if self.beta != 0.0: | ||
| mean_kl = masked_batch_mean(per_token_kl) | ||
| self._metrics[mode]["kl"].append(self.accelerator.gather(mean_kl).nanmean().item()) | ||
|
|
@@ -2885,3 +2943,16 @@ def _save_checkpoint(self, model, trial): | |
| model_name = self.args.hub_model_id.split("/")[-1] | ||
| self.create_model_card(model_name=model_name) | ||
| super()._save_checkpoint(model, trial) | ||
| if self.use_adaptive_entropy and self.is_world_process_zero(): | ||
| checkpoint_folder = f"checkpoint-{self.state.global_step}" | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. prefer |
||
| output_dir = os.path.join(self.args.output_dir, checkpoint_folder) | ||
| with open(os.path.join(output_dir, "entropy_ctrl_state.json"), "w") as f: | ||
| json.dump({"entropy_coef": self.entropy_coef}, f) | ||
|
|
||
| def _load_optimizer_and_scheduler(self, checkpoint): | ||
| super()._load_optimizer_and_scheduler(checkpoint) | ||
| if self.use_adaptive_entropy and checkpoint is not None: | ||
| path = os.path.join(checkpoint, "entropy_ctrl_state.json") | ||
| if os.path.exists(path): | ||
| with open(path) as f: | ||
| self.entropy_coef = json.load(f)["entropy_coef"] | ||
Uh oh!
There was an error while loading. Please reload this page.