diff --git a/docs/source/grpo_trainer.md b/docs/source/grpo_trainer.md index 0621d5ee689..b1e0050e1a6 100644 --- a/docs/source/grpo_trainer.md +++ b/docs/source/grpo_trainer.md @@ -185,7 +185,9 @@ While training and evaluating, we record the following reward metrics: - `reward`: The overall average reward after summing rewards across functions (weighted by `reward_weights`). - `reward_std`: The standard deviation of summed rewards across functions (weighted by `reward_weights`), computed over the full batch. - `frac_reward_zero_std`: The fraction of samples in the generation batch with a reward std of zero, implying there is little diversity for that prompt (all answers are correct or incorrect). +- `policy_loss`: The policy gradient loss value (before any entropy bonus). Logged when `entropy_coef` is nonzero or `use_adaptive_entropy=True`. - `entropy`: Average entropy of token predictions across generated completions. (If `mask_truncated_completions=True`, masked sequences tokens are excluded.) +- `entropy_coef`: The current entropy regularization coefficient. Logged when `entropy_coef` is nonzero or `use_adaptive_entropy=True`. Updated once per optimizer step when `use_adaptive_entropy=True`. - `kl`: The average KL divergence between the model and the reference model, calculated over generated completions. Logged only if `beta` is nonzero. - `clip_ratio/region_mean`: The ratio of token (or sequence, if `importance_sampling_level="sequence"`) probabilities where the GRPO objective is clipped to stay within the trust region: \\( \text{clip}\left( r_{i,t}(\theta), 1 - \epsilon_\mathrm{low}, 1 + \epsilon_\mathrm{high} \right)\,, \quad r_{i,t}(\theta) = \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})} \\). A higher value means more tokens are clipped, which constrains how much the policy $\pi_\theta$ can change. - `clip_ratio/low_mean`: The average ratio of token (or sequence, if `importance_sampling_level="sequence"`) probabilities that were clipped on the lower bound of the trust region: \\(r_{i,t}(\theta) < 1 - \epsilon_\mathrm{low}\\). @@ -641,6 +643,46 @@ and the reward will be computed as the sum of the rewards from each function, or Note that [`GRPOTrainer`] supports multiple reward functions of different types. See the parameters documentation for more details. +### Entropy regularization + +To encourage exploration and prevent the policy from collapsing to near-deterministic outputs, you can add an entropy bonus to the training objective. The entropy regularization augments the GRPO loss as follows: + +$$ +\mathcal{L}(\theta) = \mathcal{L}_{\text{GRPO}}(\theta) - \alpha \cdot \mathcal{H}(\pi_\theta), +$$ + +where \\(\mathcal{H}(\pi_\theta)\\) is the mean per-token entropy of the policy and \\(\alpha\\) is the entropy coefficient. The bonus is always the mean per-token entropy regardless of `loss_type`; it is not rescaled to match a loss type's policy normalization (e.g. Dr. GRPO's `batch_size * max_completion_length` denominator), so `entropy_coef` has the same meaning for every loss type. + +**Static entropy** — a fixed coefficient throughout training: + +```python +from trl import GRPOConfig, GRPOTrainer + +training_args = GRPOConfig(entropy_coef=0.05, ...) +``` + +**Adaptive entropy** — the coefficient is updated each optimizer step based on a target entropy, as introduced in [Skywork-OR1](https://huggingface.co/papers/2505.22312). When the current entropy falls at or below `entropy_target`, the coefficient is incremented by `entropy_coef_delta`; otherwise it is decremented. The coefficient is only applied (i.e. non-zero) while entropy is at or below the target: + +```python +training_args = GRPOConfig( + entropy_coef=0.01, # initial coefficient + use_adaptive_entropy=True, + entropy_target=5.0, # target mean per-token entropy (nats); tune for your model + entropy_coef_delta=0.005, # step size per optimizer step + entropy_coef_min=0.0, + entropy_coef_max=1.0, + ... +) +``` + + + +Typical language models have per-token entropies of 2–10 nats, so the default `entropy_target=0.2` almost never triggers regularization — the bonus only engages once entropy is at or below the target, i.e. near-complete collapse. Set it to a value meaningful for your model, e.g. close to the entropy you observe early in training (logged as the `entropy` metric). When using `top_entropy_quantile < 1.0`, `entropy_target` applies to the high-entropy token subset — that subset's entropy will be higher than the logged full-token `entropy`, so calibrate accordingly. + + + +When `use_adaptive_entropy=True`, the current entropy coefficient `entropy_coef` is saved alongside each checkpoint and restored on resume, so training is fully resumable. + ### Rapid Experimentation for GRPO RapidFire AI is an open-source experimentation engine that sits on top of TRL and lets you launch multiple GRPO configurations at once, even on a single GPU. Instead of trying configurations sequentially, RapidFire lets you **see all their learning curves earlier, stop underperforming runs, and clone promising ones with new settings in flight** without restarting. For more information, see [RapidFire AI Integration](rapidfire_integration). diff --git a/docs/source/paper_index.md b/docs/source/paper_index.md index 533bb98a8b4..6234623743a 100644 --- a/docs/source/paper_index.md +++ b/docs/source/paper_index.md @@ -232,6 +232,27 @@ training_args = GRPOConfig( ) ``` +### Skywork-OR1: Open Reasoning Models + +**📜 Paper**: https://huggingface.co/papers/2505.22312 + +Skywork-OR1 is a family of open reasoning models trained with GRPO. The paper introduces **adaptive entropy control**: an entropy regularization term `−α·H(π_θ)` is added to the GRPO objective, and the coefficient `α` is automatically adjusted each optimizer step. When the model's mean per-token entropy falls at or below a target, `α` is incremented to encourage more exploration; otherwise it is decremented. The bonus is only applied while entropy is at or below the target. To replicate this adaptive entropy control, use the following configuration: + +```python +from trl import GRPOConfig, GRPOTrainer + +training_args = GRPOConfig( + use_adaptive_entropy=True, # enable adaptive entropy control (Section 3.3 of the paper) + entropy_coef=0.01, # initial entropy regularization coefficient + entropy_target=5.0, # target mean per-token entropy (nats); tune for your model + entropy_coef_delta=0.005, # step size for coefficient updates per optimizer step +) +trainer = GRPOTrainer( + ..., + args=training_args, +) +``` + ### Beyond the 80/20 Rule: High-Entropy Minority Tokens Drive Effective Reinforcement Learning for LLM Reasoning **📜 Paper**: https://huggingface.co/papers/2506.01939 diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index 98d61d581fa..875ee9b3b78 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -1474,6 +1474,154 @@ def test_train_with_cast_lm_head_to_fp32(self, model_name): new_param = trainer.model.get_parameter(n) assert not torch.equal(param, new_param), f"Parameter {n} has not changed." + def test_train_with_static_entropy(self): + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + training_args = GRPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + report_to="none", + entropy_coef=0.1, + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + assert trainer.state.log_history[-1]["policy_loss"] is not None + assert trainer.state.log_history[-1]["entropy_coef"] is not None + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." + + def test_train_with_adaptive_entropy(self): + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + training_args = GRPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + report_to="none", + entropy_coef=0.01, + use_adaptive_entropy=True, + entropy_target=15.0, # above any realistic entropy → coef is always incremented + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + assert trainer.state.log_history[-1]["policy_loss"] is not None + assert trainer.state.log_history[-1]["entropy_coef"] is not None + # Coefficient should have increased since entropy < target throughout training + assert trainer.entropy_coef > 0.01 + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." + + @pytest.mark.parametrize("loss_type", ["grpo", "dr_grpo", "dapo", "luspo"]) + def test_entropy_bonus_scale(self, loss_type): + # Regression test: the entropy bonus is the mean per-token entropy H for every loss type (documented + # objective L = L_policy - entropy_coef * H), so it must not inherit any loss-type-specific policy + # normalization. A previous "unified" formula divided H by a global token count for the + # cispo/dapo/vespo family, making the bonus ~1/sequence_length too small; conversely, scaling the + # bonus like the dr_grpo (fixed budget) or luspo (sequence-weighted) policy term would also be wrong. + # With gradient_accumulation_steps=1 the per-step entropy contribution to the loss is + # contrib = policy_loss - loss = entropy_coef * entropy_loss, so contrib / entropy must equal + # entropy_coef for all loss types. + entropy_coef = 0.5 + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + training_args = GRPOConfig( + output_dir=self.tmp_dir, + importance_sampling_level="sequence" if loss_type == "luspo" else "token", + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=16, # reduce the completion length to reduce memory usage + gradient_accumulation_steps=1, # so contrib == entropy_coef * entropy_loss holds per step + loss_type=loss_type, + logging_steps=1, + report_to="none", + entropy_coef=entropy_coef, + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + trainer.train() + + logs = [h for h in trainer.state.log_history if "policy_loss" in h and "loss" in h and h.get("entropy")] + assert logs + ratios = sorted((h["policy_loss"] - h["loss"]) / h["entropy"] for h in logs) + ratio = ratios[len(ratios) // 2] # median, robust to per-step noise + # Every loss type regularizes the mean per-token entropy, so contrib == entropy_coef * entropy. + assert ratio == pytest.approx(entropy_coef, rel=0.3) + + def test_train_with_adaptive_entropy_gradient_accumulation(self): + # Adaptive entropy must behave correctly under gradient accumulation: the coefficient and gating are + # frozen across an accumulation window and the controller updates once per optimizer step (not once + # per micro-batch). With entropy_target above any realistic entropy the coefficient is incremented by + # entropy_coef_delta on every optimizer step, so the final value pins down the number of updates. + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + training_args = GRPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + gradient_accumulation_steps=2, # exercise the accumulation window + report_to="none", + entropy_coef=0.01, + use_adaptive_entropy=True, + entropy_target=15.0, # above any realistic entropy → coef incremented once per optimizer step + entropy_coef_delta=0.005, + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + # Exactly one increment per optimizer step (global_step counts optimizer steps, not micro-batches); + # a per-micro-batch update would overshoot this. + expected_coef = min(0.01 + 0.005 * trainer.state.global_step, 1.0) + assert trainer.entropy_coef == pytest.approx(expected_coef, abs=1e-6) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." + def test_train_with_entropy_filter(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") training_args = GRPOConfig( diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index 5736b6c0ddc..46fce0e13a9 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -295,6 +295,34 @@ class GRPOConfig(_BaseConfig): position, improving results. Range: `[0.0-1.0]`. A value of `0.0` masks all but the highest entropy token; `1.0` keeps all tokens. The paper recommends a value of `0.2`. If used with `mask_truncated_completions=True`, only tokens from non-truncated completions are considered. + entropy_coef (`float`, *optional*, defaults to `0.0`): + Coefficient of the entropy regularization term in the loss. A positive value adds an entropy bonus that + encourages exploration by keeping the policy from collapsing to near-deterministic outputs. The bonus is + always the mean per-token entropy regardless of `loss_type`; it is not rescaled to match a loss type's + policy normalization, so `entropy_coef` has the same meaning for every loss type. When + `use_adaptive_entropy=True`, this serves as the initial coefficient and is updated each optimizer step. + Has no effect when set to `0.0` (default). + use_adaptive_entropy (`bool`, *optional*, defaults to `False`): + Whether to use adaptive entropy control, introduced in + [Skywork-OR1](https://huggingface.co/papers/2505.22312). When enabled, the entropy coefficient + `entropy_coef` is updated each optimizer step: incremented by `entropy_coef_delta` when the current + entropy is below `entropy_target`, and decremented otherwise. The coefficient is only applied when + entropy is at or below `entropy_target`. + entropy_coef_min (`float`, *optional*, defaults to `0.0`): + Lower bound for the entropy coefficient when using adaptive entropy control. + entropy_coef_max (`float`, *optional*, defaults to `1.0`): + Upper bound for the entropy coefficient when using adaptive entropy control. + entropy_coef_delta (`float`, *optional*, defaults to `0.005`): + Step size for adjusting the entropy coefficient at each optimizer step during adaptive entropy control. + entropy_target (`float`, *optional*, defaults to `0.2`): + Target mean per-token entropy (in nats) used by adaptive entropy control. The coefficient is only + applied when the current entropy falls at or below this value. Measured over the same token set as + the policy loss: all completion tokens by default, or only the high-entropy subset when + `top_entropy_quantile < 1.0`. Typical language models have per-token entropies in the range 2–10 + nats, so the default of `0.2` almost never triggers regularization (only on near-complete entropy + collapse); set it close to the entropy you observe early in training (logged as the `entropy` + metric) so the bonus engages before the policy collapses (and account for the token subset when + using `top_entropy_quantile`). max_tool_calling_iterations (`int`, *optional*): Maximum number of tool-calling turns when training an agent. If `None`, there is no limit and generation stops when the model generates a response turn with no tool calls or when the total response length reaches @@ -832,6 +860,47 @@ class GRPOConfig(_BaseConfig): "non-truncated completions are considered." }, ) + entropy_coef: float = field( + default=0.0, + metadata={ + "help": "Coefficient of the entropy regularization term in the loss. A positive value adds an entropy " + "bonus that encourages exploration. The bonus is always the mean per-token entropy regardless of " + "`loss_type` (not rescaled to a loss type's policy normalization), so `entropy_coef` has the same " + "meaning for every loss type. When `use_adaptive_entropy=True`, this serves as the initial " + "coefficient and is updated each optimizer step. Has no effect when set to `0.0` (default)." + }, + ) + use_adaptive_entropy: bool = field( + default=False, + metadata={ + "help": "Whether to use adaptive entropy control, introduced in Skywork-OR1 " + "(https://huggingface.co/papers/2505.22312). When enabled, `entropy_coef` is incremented by " + "`entropy_coef_delta` when entropy is below `entropy_target`, and decremented otherwise." + }, + ) + entropy_coef_min: float = field( + default=0.0, + metadata={"help": "Lower bound for the entropy coefficient when using adaptive entropy control."}, + ) + entropy_coef_max: float = field( + default=1.0, + metadata={"help": "Upper bound for the entropy coefficient when using adaptive entropy control."}, + ) + entropy_coef_delta: float = field( + default=0.005, + metadata={"help": "Step size for adjusting the entropy coefficient during adaptive entropy control."}, + ) + entropy_target: float = field( + default=0.2, + metadata={ + "help": "Target mean per-token entropy (nats) for adaptive entropy control. The coefficient is only " + "applied when current entropy is at or below this value. Measured over the same token set as the " + "policy loss (all completion tokens, or the high-entropy subset when top_entropy_quantile < 1.0). " + "Typical language models have per-token entropies of 2–10 nats, so the default of 0.2 almost never " + "triggers regularization (only on near-complete collapse); set it close to the entropy observed " + "early in training and tune from there." + }, + ) max_tool_calling_iterations: int | None = field( default=None, metadata={ diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index ea4f85d3ae1..ff89013a3e6 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -17,6 +17,7 @@ import copy import importlib.resources as pkg_resources import inspect +import json import math import os import sys @@ -667,6 +668,13 @@ def __init__( f"Unknown importance sampling level: {self.importance_sampling_level}. " "Possible values are 'token' and 'sequence'." ) + self.entropy_coef = args.entropy_coef + self.use_adaptive_entropy = args.use_adaptive_entropy + # Cached entropy from the last optimizer step; inf so the first accumulation window + # applies no bonus until a real measurement arrives (conservative default). + self._last_world_entropy = float("inf") + if self.use_liger_kernel and (self.entropy_coef != 0.0 or self.use_adaptive_entropy): + raise NotImplementedError("Entropy bonus is not supported with Liger kernel.") # Datasets self.shuffle_dataset = args.shuffle_dataset @@ -2760,26 +2768,92 @@ def _compute_loss(self, model, inputs): if self.loss_type in ["grpo", "sapo"]: loss = ((per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).mean() normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 # no accum in eval + policy_loss = loss.detach() loss = loss / normalizer elif self.loss_type == "bnpo": loss = (per_token_loss * mask).sum() / mask.sum().clamp(min=1.0) normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 # no accum in eval + policy_loss = loss.detach() loss = loss / normalizer elif self.loss_type == "dr_grpo": loss = (per_token_loss * mask).sum() / (per_token_loss.size(0) * self.max_completion_length) normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 # no accum in eval + policy_loss = loss.detach() loss = loss / normalizer elif self.loss_type in ["cispo", "dapo", "vespo"]: normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes loss = (per_token_loss * mask).sum() / normalizer + policy_loss = loss.detach() elif self.loss_type == "luspo": # Unless importance_sampling_level="token" (not recommended here), per_token_loss is expected to be (B, 1) loss = (per_token_loss * mask.sum(1, keepdim=True)).mean() normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 + policy_loss = loss.detach() loss = loss / normalizer else: raise ValueError(f"Unknown loss type: {self.loss_type}") + # Entropy bonus: add entropy regularization to encourage exploration. + # Gate: run whenever a non-zero static coef is set OR adaptive mode is enabled. Adaptive must always run even + # when self.entropy_coef has been decremented to entropy_coef_min (default 0) so it can recover once entropy + # drops below entropy_target again. + if self.entropy_coef != 0.0 or self.use_adaptive_entropy: + # When top_entropy_quantile < 1.0, entropy_mask restricts policy gradients to high-entropy + # tokens. Use the same effective mask for the entropy bonus so it acts on the same tokens. + effective_mask = mask if entropy_mask is None else mask * entropy_mask + # Entropy bonus = mean per-token entropy H (the documented objective L = L_policy - coef * H), so + # H does not depend on how each loss type normalizes its policy term. The term is computed so that + # it accumulates to H over the optimizer step for every loss type and matches world_entropy below. + # The only wrinkle is the normalizer: most loss types divide by the gradient accumulation step + # count, but cispo/dapo/vespo divide by a global token count. + if self.loss_type in ["cispo", "dapo", "vespo"]: + # normalizer is a global token count, so summing the entropies (instead of averaging them + # again) makes the term accumulate over the optimizer step to the global mean per-token + # entropy, like the other loss types. + entropy_loss = (entropies * effective_mask).sum() / normalizer + else: + # Mean per-token entropy of active tokens, scaled for gradient accumulation. + entropy_loss = (entropies * effective_mask).sum() / effective_mask.sum().clamp(min=1.0) / normalizer + + # Apply the coefficient and gating from the end of the previous optimizer step, so that every + # micro-batch in the current accumulation window applies the same entropy bonus. The adaptive + # update below only takes effect on the next step. + if self.use_adaptive_entropy: + apply_coef = self.entropy_coef if self._last_world_entropy <= self.args.entropy_target else 0.0 + else: + apply_coef = self.entropy_coef + + loss = loss - apply_coef * entropy_loss + + self._metrics[mode]["policy_loss"].append(self.accelerator.gather(policy_loss).nanmean().item()) + + # Adaptive update: once per optimizer step, measure the global token-weighted entropy and adjust + # the coefficient for the next step. Gated on train mode so evaluation cannot mutate the entropy + # controller state, and on sync_gradients so the all-reduce runs once per optimizer step rather + # than on every micro-batch of the accumulation window. + if self.use_adaptive_entropy and mode == "train" and self.accelerator.sync_gradients: + # Reduce sum and token count jointly for a true global mean (unbiased when ranks have + # different completion lengths). + entropy_stats = self.accelerator.reduce( + torch.stack([(entropies * effective_mask).sum(), effective_mask.sum()]).detach(), + reduction="sum", + ) + world_entropy = (entropy_stats[0] / entropy_stats[1].clamp(min=1.0)).item() + if world_entropy <= self.args.entropy_target: + self.entropy_coef = min( + self.entropy_coef + self.args.entropy_coef_delta, self.args.entropy_coef_max + ) + else: + self.entropy_coef = max( + self.entropy_coef - self.args.entropy_coef_delta, self.args.entropy_coef_min + ) + self._last_world_entropy = world_entropy + + # Log entropy_coef on train optimizer-step boundaries (constant for static control; updated just + # above for adaptive control). sync_gradients is always True in eval (no accumulation context). + if mode == "train" and self.accelerator.sync_gradients: + self._metrics[mode]["entropy_coef"].append(self.entropy_coef) + # The policy loss above is scaled for gradient accumulation (HF auto-scaling is off here), so scale aux too if self.aux_loss_enabled: normalizer = self.current_gradient_accumulation_steps if mode == "train" else 1.0 @@ -2929,3 +3003,18 @@ def _save_checkpoint(self, model, trial): model_name = self.args.hub_model_id.split("/")[-1] self.create_model_card(model_name=model_name) super()._save_checkpoint(model, trial) + if self.use_adaptive_entropy and self.args.should_save: + checkpoint_folder = f"checkpoint-{self.state.global_step}" + output_dir = os.path.join(self._get_output_dir(trial=trial), checkpoint_folder) + with open(os.path.join(output_dir, "entropy_ctrl_state.json"), "w") as f: + json.dump({"entropy_coef": self.entropy_coef, "last_world_entropy": self._last_world_entropy}, f) + + def _load_optimizer_and_scheduler(self, checkpoint): + super()._load_optimizer_and_scheduler(checkpoint) + if self.use_adaptive_entropy and checkpoint is not None: + path = os.path.join(checkpoint, "entropy_ctrl_state.json") + if os.path.exists(path): + with open(path) as f: + state = json.load(f) + self.entropy_coef = state["entropy_coef"] + self._last_world_entropy = state.get("last_world_entropy", float("inf"))