-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Add entropy regularization to GRPO #6140
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 9 commits
ac50a11
dcaaf67
0f6306e
9b1cc65
e944713
f47d5a5
2484e70
4507747
9b70a4a
9d79e4a
46c8a64
3f7a669
a05c979
fe03dd1
f099349
5288cd5
03f4208
dbc0c75
391da7a
506fbf9
8a6b53d
474b30c
a0b9ec6
608b1e0
81841ad
bee5126
5c442a0
2f34d15
806078d
2845ef4
2ed11c0
7f0562b
76255d3
fc76d4b
bed5188
6e8f498
607d911
0cfad37
8e05132
f15e04a
bccd8eb
0f3e145
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,7 @@ | |
| import copy | ||
| import importlib.resources as pkg_resources | ||
| import inspect | ||
| import json | ||
| import math | ||
| import os | ||
| import sys | ||
|
|
@@ -667,6 +668,10 @@ def __init__( | |
| f"Unknown importance sampling level: {self.importance_sampling_level}. " | ||
| "Possible values are 'token' and 'sequence'." | ||
| ) | ||
| self.entropy_coef = args.entropy_coef | ||
| self.use_adaptive_entropy = args.use_adaptive_entropy | ||
| if self.use_liger_kernel and (self.entropy_coef != 0.0 or self.use_adaptive_entropy): | ||
| raise NotImplementedError("Entropy bonus is not supported with Liger kernel.") | ||
|
cursor[bot] marked this conversation as resolved.
|
||
|
|
||
| # Datasets | ||
| self.shuffle_dataset = args.shuffle_dataset | ||
|
|
@@ -2736,6 +2741,45 @@ def _compute_loss(self, model, inputs): | |
| else: | ||
| raise ValueError(f"Unknown loss type: {self.loss_type}") | ||
|
|
||
| # Capture the pure policy loss for logging before entropy/aux modify it | ||
| policy_loss = loss.detach() | ||
|
|
||
| # Entropy bonus: add entropy regularization to encourage exploration. | ||
| # Gate: run whenever a non-zero static coef is set OR adaptive mode is enabled. Adaptive must always run even | ||
| # when self.entropy_coef has been decremented to entropy_coef_min (default 0) so it can recover once entropy | ||
| # drops below entropy_target again. | ||
| if self.entropy_coef != 0.0 or self.use_adaptive_entropy: | ||
| if self.loss_type in ["grpo", "sapo", "luspo"]: | ||
| entropy_loss = ((entropies * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).mean() / normalizer | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The entropy block lumps luspo with grpo/sapo, but luspo's actual loss is entropy_loss = (entropies * mask).sum(-1).mean() / normalizer |
||
| elif self.loss_type == "bnpo": | ||
| entropy_loss = (entropies * mask).sum() / mask.sum().clamp(min=1.0) / normalizer | ||
| elif self.loss_type == "dr_grpo": | ||
| entropy_loss = (entropies * mask).sum() / (entropies.size(0) * self.max_completion_length) / normalizer | ||
| elif self.loss_type in ["cispo", "dapo", "vespo"]: | ||
| entropy_loss = (entropies * mask).sum() / normalizer | ||
|
|
||
| # Mean per-token entropy in nats across ranks — computed independently of the loss | ||
| # normalizer so its scale matches entropy_target (loss-scaled entropy_loss would not). | ||
| world_entropy = self.accelerator.reduce( | ||
| ((entropies * mask).sum() / mask.sum().clamp(min=1.0)).detach(), reduction="mean" | ||
| ).item() | ||
|
cursor[bot] marked this conversation as resolved.
Outdated
|
||
| if self.use_adaptive_entropy: | ||
| # Update the coefficient once per optimizer step, not per micro-batch | ||
| if self.accelerator.sync_gradients: | ||
| if world_entropy < self.args.entropy_target: | ||
| self.entropy_coef = min( | ||
| self.entropy_coef + self.args.entropy_coef_delta, self.args.entropy_coef_max | ||
| ) | ||
| else: | ||
| self.entropy_coef = max( | ||
| self.entropy_coef - self.args.entropy_coef_delta, self.args.entropy_coef_min | ||
| ) | ||
| apply_coef = self.entropy_coef if world_entropy <= self.args.entropy_target else 0.0 | ||
|
cursor[bot] marked this conversation as resolved.
Outdated
cursor[bot] marked this conversation as resolved.
Outdated
|
||
| else: | ||
| apply_coef = self.entropy_coef | ||
|
|
||
| loss = loss - apply_coef * entropy_loss | ||
|
cursor[bot] marked this conversation as resolved.
cursor[bot] marked this conversation as resolved.
|
||
|
|
||
| # The policy loss above is scaled for gradient accumulation (HF auto-scaling is off here), so scale aux too | ||
| if self.aux_loss_enabled: | ||
| normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 | ||
|
|
@@ -2751,6 +2795,11 @@ def masked_batch_mean(x): | |
| else: | ||
| return (x * mask).sum() / completion_token_count | ||
|
|
||
| self._metrics[mode]["policy_loss"].append(self.accelerator.reduce(policy_loss, reduction="mean").item()) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. policy_loss is captured after
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. another remark, it can be misleading to have
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. and just for matching home-style: self._metrics[mode]["policy_loss"].append(self.accelerator.gather(policy_loss).nanmean().item()) |
||
| if self.entropy_coef != 0.0 or self.use_adaptive_entropy: | ||
| self._metrics[mode]["entropy_loss"].append(world_entropy) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We would have two near-duplicate entropy metrics. We could drop it, and just log |
||
| self._metrics[mode]["entropy_coef"].append(self.entropy_coef) | ||
|
|
||
| if self.beta != 0.0: | ||
| mean_kl = masked_batch_mean(per_token_kl) | ||
| self._metrics[mode]["kl"].append(self.accelerator.gather(mean_kl).nanmean().item()) | ||
|
|
@@ -2885,3 +2934,16 @@ def _save_checkpoint(self, model, trial): | |
| model_name = self.args.hub_model_id.split("/")[-1] | ||
| self.create_model_card(model_name=model_name) | ||
| super()._save_checkpoint(model, trial) | ||
| if self.use_adaptive_entropy and self.is_world_process_zero(): | ||
| checkpoint_folder = f"checkpoint-{self.state.global_step}" | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. prefer |
||
| output_dir = os.path.join(self.args.output_dir, checkpoint_folder) | ||
| with open(os.path.join(output_dir, "entropy_ctrl_state.json"), "w") as f: | ||
| json.dump({"entropy_coef": self.entropy_coef}, f) | ||
|
|
||
| def _load_optimizer_and_scheduler(self, checkpoint): | ||
| super()._load_optimizer_and_scheduler(checkpoint) | ||
| if self.use_adaptive_entropy and checkpoint is not None: | ||
| path = os.path.join(checkpoint, "entropy_ctrl_state.json") | ||
| if os.path.exists(path): | ||
| with open(path) as f: | ||
| self.entropy_coef = json.load(f)["entropy_coef"] | ||
Uh oh!
There was an error while loading. Please reload this page.