-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Add entropy regularization to GRPO #6140
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 6 commits
ac50a11
dcaaf67
0f6306e
9b1cc65
e944713
f47d5a5
2484e70
4507747
9b70a4a
9d79e4a
46c8a64
3f7a669
a05c979
fe03dd1
f099349
5288cd5
03f4208
dbc0c75
391da7a
506fbf9
8a6b53d
474b30c
a0b9ec6
608b1e0
81841ad
bee5126
5c442a0
2f34d15
806078d
2845ef4
2ed11c0
7f0562b
76255d3
fc76d4b
bed5188
6e8f498
607d911
0cfad37
8e05132
f15e04a
bccd8eb
0f3e145
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,7 @@ | |
| import copy | ||
| import importlib.resources as pkg_resources | ||
| import inspect | ||
| import json | ||
| import math | ||
| import os | ||
| import sys | ||
|
|
@@ -667,6 +668,10 @@ def __init__( | |
| f"Unknown importance sampling level: {self.importance_sampling_level}. " | ||
| "Possible values are 'token' and 'sequence'." | ||
| ) | ||
| self.entropy_coef = args.entropy_coef | ||
| self.use_adaptive_entropy = args.use_adaptive_entropy | ||
| if self.use_liger_kernel and self.entropy_coef != 0.0: | ||
| raise NotImplementedError("Entropy bonus is not supported with Liger kernel.") | ||
|
cursor[bot] marked this conversation as resolved.
|
||
|
|
||
| # Datasets | ||
| self.shuffle_dataset = args.shuffle_dataset | ||
|
|
@@ -2736,6 +2741,33 @@ def _compute_loss(self, model, inputs): | |
| else: | ||
| raise ValueError(f"Unknown loss type: {self.loss_type}") | ||
|
|
||
| # Entropy bonus: add entropy regularization to encourage exploration | ||
| if self.entropy_coef != 0.0: | ||
|
cursor[bot] marked this conversation as resolved.
Outdated
|
||
| if self.loss_type in ["grpo", "sapo", "luspo"]: | ||
| entropy_loss = ((entropies * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).mean() / normalizer | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The entropy block lumps `luspo` with `grpo`/`sapo`, but luspo's actual loss is reduced differently. Suggested change: `entropy_loss = (entropies * mask).sum(-1).mean() / normalizer` |
||
| elif self.loss_type == "bnpo": | ||
| entropy_loss = (entropies * mask).sum() / mask.sum().clamp(min=1.0) / normalizer | ||
| elif self.loss_type == "dr_grpo": | ||
| entropy_loss = (entropies * mask).sum() / (entropies.size(0) * self.max_completion_length) / normalizer | ||
| elif self.loss_type in ["cispo", "dapo", "vespo"]: | ||
| entropy_loss = (entropies * mask).sum() / normalizer | ||
|
|
||
| world_entropy = self.accelerator.reduce(entropy_loss.detach(), reduction="mean").item() | ||
| if self.use_adaptive_entropy: | ||
| if world_entropy < self.args.entropy_target: | ||
| self.entropy_coef = min( | ||
| self.entropy_coef + self.args.entropy_coef_delta, self.args.entropy_coef_max | ||
| ) | ||
| else: | ||
| self.entropy_coef = max( | ||
| self.entropy_coef - self.args.entropy_coef_delta, self.args.entropy_coef_min | ||
| ) | ||
|
cursor[bot] marked this conversation as resolved.
|
||
| apply_coef = self.entropy_coef if world_entropy <= self.args.entropy_target else 0.0 | ||
|
cursor[bot] marked this conversation as resolved.
Outdated
cursor[bot] marked this conversation as resolved.
Outdated
|
||
| else: | ||
| apply_coef = self.entropy_coef | ||
|
|
||
| loss = loss - apply_coef * entropy_loss | ||
|
cursor[bot] marked this conversation as resolved.
cursor[bot] marked this conversation as resolved.
|
||
|
|
||
| # The policy loss above is scaled for gradient accumulation (HF auto-scaling is off here), so scale aux too | ||
| if self.aux_loss_enabled: | ||
| normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 | ||
|
|
@@ -2751,6 +2783,11 @@ def masked_batch_mean(x): | |
| else: | ||
| return (x * mask).sum() / completion_token_count | ||
|
|
||
| self._metrics[mode]["policy_loss"].append(self.accelerator.reduce(loss.detach(), reduction="mean").item()) | ||
|
cursor[bot] marked this conversation as resolved.
Outdated
|
||
| if self.entropy_coef != 0.0: | ||
| self._metrics[mode]["entropy_loss"].append(world_entropy) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We would have two near-duplicate entropy metrics. We could drop it and just log … (the suggested replacement code was not captured in this extract) |
||
| self._metrics[mode]["entropy_coef"].append(self.entropy_coef) | ||
|
|
||
| if self.beta != 0.0: | ||
| mean_kl = masked_batch_mean(per_token_kl) | ||
| self._metrics[mode]["kl"].append(self.accelerator.gather(mean_kl).nanmean().item()) | ||
|
|
@@ -2885,3 +2922,16 @@ def _save_checkpoint(self, model, trial): | |
| model_name = self.args.hub_model_id.split("/")[-1] | ||
| self.create_model_card(model_name=model_name) | ||
| super()._save_checkpoint(model, trial) | ||
| if self.use_adaptive_entropy and self.is_world_process_zero(): | ||
| checkpoint_folder = f"checkpoint-{self.state.global_step}" | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Prefer … (the suggested replacement code was not captured in this extract) |
||
| output_dir = os.path.join(self.args.output_dir, checkpoint_folder) | ||
| with open(os.path.join(output_dir, "entropy_ctrl_state.json"), "w") as f: | ||
| json.dump({"entropy_coef": self.entropy_coef}, f) | ||
|
|
||
| def _load_optimizer_and_scheduler(self, checkpoint): | ||
| super()._load_optimizer_and_scheduler(checkpoint) | ||
| if self.use_adaptive_entropy and checkpoint is not None: | ||
| path = os.path.join(checkpoint, "entropy_ctrl_state.json") | ||
| if os.path.exists(path): | ||
| with open(path) as f: | ||
| self.entropy_coef = json.load(f)["entropy_coef"] | ||
Uh oh!
There was an error while loading. Please reload this page.